In [None]:
# run this cell if langsmith is used to trace the chain
import os
from helper import get_api_key

os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_API_KEY"] = get_api_key(1)
os.environ['LANGCHAIN_PROJECT'] = 'default'
os.environ["LANGCHAIN_TRACING_V2"] = "true"

In [1]:
from typing import List

from langchain_community.llms.llamacpp import LlamaCpp
from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings

class LlamaCppEmbeddings_(LlamaCppEmbeddings):
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using the Llama model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = [self.client.embed(text)[0] for text in texts]
        return [list(map(float, e)) for e in embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using the Llama model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embedding = self.client.embed(text)[0]
        return list(map(float, embedding))


In [2]:
# using local installed llama_cpp_python
import multiprocessing

from langchain_community.chat_models import ChatLlamaCpp
""""""

local_model = "./models/xLAM-7b-fc-r.Q8_0.gguf"
local_embedding = "./models/Phi-3.1-mini-4k-instruct-Q6_K_L.gguf"

llm = ChatLlamaCpp(
    temperature=0.2,
    model_path=local_model,
    n_ctx=5000,
    n_gpu_layers=100,
    n_batch=1000,  # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.
    max_tokens=512,
    n_threads=multiprocessing.cpu_count() - 1,
    repeat_penalty=1,
    top_p=0.8
)

llm2 = ChatLlamaCpp(
    temperature=0.2,
    model_path=local_embedding,
    n_ctx=5000,
    n_gpu_layers=100,
    n_batch=1000,  # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.
    max_tokens=512,
    n_threads=multiprocessing.cpu_count() - 1,
    repeat_penalty=1,
    top_p=0.8
)

#Use local model for embedding
embeddings = LlamaCppEmbeddings_(model_path=local_embedding)


llama_model_loader: loaded meta data with 25 key-value pairs and 273 tensors from ./models/xLAM-7b-fc-r.Q8_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = xlam-FC-7b
llama_model_loader: - kv   2:                          llama.block_count u32              = 30
llama_model_loader: - kv   3:                       llama.context_length u32              = 4096
llama_model_loader: - kv   4:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 11008
llama_model_loader: - kv   6:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv   7:              llama.attention.head_count_kv u32  

In [None]:
# using locally served llama_cpp with openai wrapper
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings

llm = ChatOpenAI(
    openai_api_base="http://llama_server:8080/",
    temperature=0.3,
)

import openai
client = openai.OpenAI(
    base_url="http://llama_server:8080/", # "http://<Your api-server IP>:port"
    api_key = "no_key"
)


In [3]:
# get the SQL utilities from langchain
from langchain_community.utilities import SQLDatabase


db = SQLDatabase.from_uri("mysql://user:password@mysql_db:3306/classicmodels")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM products LIMIT 5;")

mysql
['customers', 'employees', 'offices', 'orderdetails', 'orders', 'payments', 'productlines', 'products']


"[('S10_1678', '1969 Harley Davidson Ultimate Chopper', 'Motorcycles', '1:10', 'Min Lin Diecast', 'This replica features working kickstand, front suspension, gear-shift lever, footbrake lever, drive chain, wheels and steering. All parts are particularly delicate due to their precise scale and require special care and attention.', 7933, Decimal('48.81'), Decimal('95.70')), ('S10_1949', '1952 Alpine Renault 1300', 'Classic Cars', '1:10', 'Classic Metal Creations', 'Turnable front wheels; steering function; detailed interior; detailed engine; opening hood; opening trunk; opening doors; and detailed chassis.', 7305, Decimal('98.58'), Decimal('214.30')), ('S10_2016', '1996 Moto Guzzi 1100i', 'Motorcycles', '1:10', 'Highway 66 Mini Classics', 'Official Moto Guzzi logos and insignias, saddle bags located on side of motorcycle, detailed engine, working steering, working suspension, two leather seats, luggage rack, dual exhaust pipes, small saddle bag located on handle bars, two-tone paint with

In [None]:
db.run("SELECT buyPrice from products where productName = '1969 Harley Davidson Ultimate Chopper';")

In [None]:
db.run("SELECT firstName, lastName, jobTitle FROM employees WHERE jobTitle LIKE '%Marketing%';")

In [4]:
# More SQL tools from langchain
from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()

list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")

tables=list_tables_tool.invoke("")
print(tables)
# split the string by ", "
tables = tables.split(", ")

for table in tables:
    print(get_schema_tool.invoke(table))


customers, employees, offices, orderdetails, orders, payments, productlines, products

CREATE TABLE customers (
	`customerNumber` INTEGER NOT NULL, 
	`customerName` VARCHAR(50) NOT NULL, 
	`contactLastName` VARCHAR(50) NOT NULL, 
	`contactFirstName` VARCHAR(50) NOT NULL, 
	phone VARCHAR(50) NOT NULL, 
	`addressLine1` VARCHAR(50) NOT NULL, 
	`addressLine2` VARCHAR(50), 
	city VARCHAR(50) NOT NULL, 
	state VARCHAR(50), 
	`postalCode` VARCHAR(15), 
	country VARCHAR(50) NOT NULL, 
	`salesRepEmployeeNumber` INTEGER, 
	`creditLimit` DECIMAL(10, 2), 
	PRIMARY KEY (`customerNumber`), 
	CONSTRAINT customers_ibfk_1 FOREIGN KEY(`salesRepEmployeeNumber`) REFERENCES employees (`employeeNumber`)
)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci

/*
3 rows from customers table:
customerNumber	customerName	contactLastName	contactFirstName	phone	addressLine1	addressLine2	city	state	postalCode	country	salesRepEmployeeNumber	creditLimit
103	Atelier graphique	Schmitt	Carine 	40.32.2555	54,

In [13]:
# First we can create the prompt template to achieve our task.abs
from langchain_core.prompts import PromptTemplate

query_template = """
Based on the table schema and relevant information below, write a SQL query that would answer the user's question.
{schema}

User's question: {input}
SQL Query:
"""

query_prompt = PromptTemplate.from_template(query_template)

In [6]:
def get_schema(_):
    return db.get_table_info()


In [14]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

sql_query_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    |query_prompt
    |llm.bind(stop=[";"])
    |StrOutputParser()
)

In [15]:
query = sql_query_chain.invoke({"input":"What is the price of '1969 Harley Davidson Ultimate Chopper'?"})
print(query)

Llama.generate: prefix-match hit

llama_print_timings:        load time =     686.08 ms
llama_print_timings:      sample time =       1.44 ms /    28 runs   (    0.05 ms per token, 19512.20 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (    -nan ms per token,     -nan tokens per second)
llama_print_timings:        eval time =     621.57 ms /    28 runs   (   22.20 ms per token,    45.05 tokens per second)
llama_print_timings:       total time =     649.53 ms /    28 tokens



SELECT `buyPrice`
FROM products
WHERE `productName` = '1969 Harley Davidson Ultimate Chopper'


In [16]:
def run_query(query):
    return db.run(query)

In [17]:
answer_prompt = PromptTemplate.from_template("""
Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {input}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

In [18]:
full_chain = (
    RunnablePassthrough.assign(query=sql_query_chain).assign(
        result=lambda vars: run_query(vars["query"]),
    )
    | answer_prompt
    | llm2
    | StrOutputParser()
)


In [19]:
full_chain.invoke({"input":"What is the price of the '1969 Harley Davidson Ultimate Chopper'?"})

Llama.generate: prefix-match hit

llama_print_timings:        load time =     686.08 ms
llama_print_timings:      sample time =       1.16 ms /    24 runs   (    0.05 ms per token, 20689.66 tokens per second)
llama_print_timings: prompt eval time =      38.05 ms /    21 tokens (    1.81 ms per token,   551.88 tokens per second)
llama_print_timings:        eval time =     491.64 ms /    23 runs   (   21.38 ms per token,    46.78 tokens per second)
llama_print_timings:       total time =     549.49 ms /    44 tokens

llama_print_timings:        load time =      94.95 ms
llama_print_timings:      sample time =       0.69 ms /    27 runs   (    0.03 ms per token, 39301.31 tokens per second)
llama_print_timings: prompt eval time =      94.77 ms /    96 tokens (    0.99 ms per token,  1012.94 tokens per second)
llama_print_timings:        eval time =     272.10 ms /    26 runs   (   10.47 ms per token,    95.55 tokens per second)
llama_print_timings:       total time =     381.07 ms /   122 

" The price of the '1969 Harley Davidson Ultimate Chopper' is $48.81."

In [20]:
# get a list of example prompts from sql_examples.json
# read json file
import json
with open("sql_examples.json", "r") as read_file:
    examples = json.load(read_file)

In [21]:
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector

# create example selector using vector search
vectorstore = Chroma()
vectorstore.delete_collection()
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    embeddings,
    vectorstore,
    k=2,
    input_keys=["input"],
)
example_selector.select_examples({"input": "how many employees we have?"})

  warn_deprecated(

llama_print_timings:        load time =     205.98 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =     191.45 ms /     7 tokens (   27.35 ms per token,    36.56 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =     206.20 ms /     8 tokens

llama_print_timings:        load time =     205.98 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =     275.74 ms /    11 tokens (   25.07 ms per token,    39.89 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =     291.86 ms /    12 tokens

llama_

[{'input': 'what is price of `1968 Ford Mustang`',
  'query': "SELECT `buyPrice`, `MSRP` FROM products  WHERE `productName` = '1968 Ford Mustang' LIMIT 1;"},
 {'input': 'Get the highest payment amount made by any customer.',
  'query': 'SELECT MAX(amount) FROM payments;'}]

In [22]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate

example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}\nSQLQuery:"),
        ("ai", "{query}"),
    ]
)

few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    example_selector=example_selector,
    input_variables=["input"],
)
print(few_shot_prompt.format(input="How many products are there?"))


llama_print_timings:        load time =     205.98 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =     397.69 ms /     6 tokens (   66.28 ms per token,    15.09 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =     412.71 ms /     7 tokens


Human: What is the total number of orders?
SQLQuery:
AI: SELECT COUNT(orderNumber) FROM orders;
Human: What is the total number of orders?
SQLQuery:
AI: SELECT COUNT(orderNumber) FROM orders;


In [23]:
query_template = """
Based on the table schema and relevant information below, write a SQL query that would answer the user's question.
{schema}

Below are a number of examples questions and corresponding SQL queries.
"""

final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", query_template),
        few_shot_prompt,
        ("human", "{input}"),
    ]
)
print(final_prompt.format(input="How many products are there?",schema="some table info"))


llama_print_timings:        load time =     205.98 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =     219.15 ms /     6 tokens (   36.53 ms per token,    27.38 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =     233.43 ms /     7 tokens


System: 
Based on the table schema and relevant information below, write a SQL query that would answer the user's question.
some table info

Below are a number of examples questions and corresponding SQL queries.

Human: What is the total number of orders?
SQLQuery:
AI: SELECT COUNT(orderNumber) FROM orders;
Human: What is the total number of orders?
SQLQuery:
AI: SELECT COUNT(orderNumber) FROM orders;
Human: How many products are there?


In [24]:
# a function to read a csv file containing table name and table description as polars dataframe, then output each as string of text with fromat table name: xxx \n table description: xxx
import polars as pl
def get_table_info(file_path:str):
    df = pl.read_csv(file_path)
    table_info = ""
    for i in range(len(df)):
        table_info += f"Table Name: {df['Table Name'][i]}\nTable Description: {df['Description'][i]}\n\n"
    return table_info

def table_info_runnable(_):
    return get_table_info("tables_description.csv")

def get_tables_as_list(tables_name: str) -> list:
    tables = tables_name.strip().split(", ")
    return tables

def get_schema_from_table(table_names:list):
    schema = ""
    useful_table = db.get_usable_table_names()
    for table in table_names:
        print(table)
        if table in useful_table:
            schema += get_schema_tool.invoke(table)
    return schema

In [34]:
table_info = get_table_info(None)
print(table_info)

Table Name: customers
Table Description: Stores customer information, including contact details, address, credit limit, and sales representative.

Table Name: employees
Table Description: Contains employee data, such as names, extensions, email addresses, job titles, and office codes.

Table Name: offices
Table Description: Holds office information, including city, phone number, address, state, country, and postal code.

Table Name: orderdetails
Table Description: Records order details, including product codes, quantities, prices, and line numbers.

Table Name: orders
Table Description: Stores order information, including order dates, required dates, shipped dates, status, and customer numbers.

Table Name: payments
Table Description: Tracks payment details, including check numbers, payment dates, and amounts.

Table Name: productlines
Table Description: Describes product lines, with text and HTML descriptions, and optional images.

Table Name: products
Table Description: Contains prod

In [25]:
table_details_prompt = PromptTemplate.from_template("""
Return only the names of ALL the tables that MIGHT be relevant to the user question using the information in Table Descritption. The tables are:
{table_info}

Here is the user question:{input}
Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed. Only output the relevant table name and nothing else.

""")
                                                    

In [27]:
select_table_chain = (
    RunnablePassthrough.assign(table_info=table_info_runnable)
    | table_details_prompt
    | llm2.bind(stop=[";"])
    | StrOutputParser()
    | get_tables_as_list
    | get_schema_from_table
)


In [28]:
selected_list = select_table_chain.invoke({"input":"How many cutomers with order count more than 5"})
print(selected_list)

Llama.generate: prefix-match hit

llama_print_timings:        load time =      94.95 ms
llama_print_timings:      sample time =       0.11 ms /     4 runs   (    0.03 ms per token, 37037.04 tokens per second)
llama_print_timings: prompt eval time =     175.12 ms /   338 tokens (    0.52 ms per token,  1930.12 tokens per second)
llama_print_timings:        eval time =      34.86 ms /     3 runs   (   11.62 ms per token,    86.05 tokens per second)
llama_print_timings:       total time =     211.79 ms /   341 tokens


customers
orders

CREATE TABLE customers (
	`customerNumber` INTEGER NOT NULL, 
	`customerName` VARCHAR(50) NOT NULL, 
	`contactLastName` VARCHAR(50) NOT NULL, 
	`contactFirstName` VARCHAR(50) NOT NULL, 
	phone VARCHAR(50) NOT NULL, 
	`addressLine1` VARCHAR(50) NOT NULL, 
	`addressLine2` VARCHAR(50), 
	city VARCHAR(50) NOT NULL, 
	state VARCHAR(50), 
	`postalCode` VARCHAR(15), 
	country VARCHAR(50) NOT NULL, 
	`salesRepEmployeeNumber` INTEGER, 
	`creditLimit` DECIMAL(10, 2), 
	PRIMARY KEY (`customerNumber`), 
	CONSTRAINT customers_ibfk_1 FOREIGN KEY(`salesRepEmployeeNumber`) REFERENCES employees (`employeeNumber`)
)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci

/*
3 rows from customers table:
customerNumber	customerName	contactLastName	contactFirstName	phone	addressLine1	addressLine2	city	state	postalCode	country	salesRepEmployeeNumber	creditLimit
103	Atelier graphique	Schmitt	Carine 	40.32.2555	54, rue Royale	None	Nantes	None	44000	France	1370	21000.00
112	Signal Gi

In [29]:
# now chain up the final_prompt which have few shot selection with schema output by select_table_chain

sql_query_chain = (
    RunnablePassthrough.assign(schema=select_table_chain)
    |final_prompt
    |llm.bind(stop=[";"])
    |StrOutputParser()
)

In [31]:
new_query = sql_query_chain.invoke({"input":"How many cutomers with order count more than 5"})
print(new_query)

Llama.generate: prefix-match hit

llama_print_timings:        load time =      94.95 ms
llama_print_timings:      sample time =       0.13 ms /     4 runs   (    0.03 ms per token, 29850.75 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (    -nan ms per token,     -nan tokens per second)
llama_print_timings:        eval time =      58.99 ms /     4 runs   (   14.75 ms per token,    67.81 tokens per second)
llama_print_timings:       total time =      62.43 ms /     4 tokens


customers
orders



llama_print_timings:        load time =     205.98 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =     558.93 ms /    12 tokens (   46.58 ms per token,    21.47 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =     574.73 ms /    13 tokens
Llama.generate: prefix-match hit

llama_print_timings:        load time =     686.08 ms
llama_print_timings:      sample time =       2.29 ms /    47 runs   (    0.05 ms per token, 20568.93 tokens per second)
llama_print_timings: prompt eval time =     219.47 ms /   411 tokens (    0.53 ms per token,  1872.69 tokens per second)
llama_print_timings:        eval time =     827.16 ms /    46 runs   (   17.98 ms per token,    55.61 tokens per second)
llama_print_timings:       total time =    1105.47 ms /   457 


SELECT COUNT(customerNumber) 
FROM (
  SELECT customerNumber 
  FROM orders 
  GROUP BY customerNumber 
  HAVING COUNT(orderNumber) > 5
) AS subquery;


In [32]:
few_shots_full_chain = (
    RunnablePassthrough.assign(query=sql_query_chain).assign(
        result=lambda vars: run_query(vars["query"]),
    )
    | answer_prompt
    | llm2
    | StrOutputParser()
)

In [33]:
few_shots_full_chain.invoke({"input":"How many cutomers with order count more than 5"})

Llama.generate: prefix-match hit

llama_print_timings:        load time =      94.95 ms
llama_print_timings:      sample time =       0.10 ms /     4 runs   (    0.02 ms per token, 40404.04 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (    -nan ms per token,     -nan tokens per second)
llama_print_timings:        eval time =      46.47 ms /     4 runs   (   11.62 ms per token,    86.08 tokens per second)
llama_print_timings:       total time =      48.26 ms /     4 tokens


customers
orders



llama_print_timings:        load time =     205.98 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =     280.00 ms /    12 tokens (   23.33 ms per token,    42.86 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =     294.61 ms /    13 tokens
Llama.generate: prefix-match hit

llama_print_timings:        load time =     686.08 ms
llama_print_timings:      sample time =       1.72 ms /    36 runs   (    0.05 ms per token, 20881.67 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (    -nan ms per token,     -nan tokens per second)
llama_print_timings:        eval time =     646.11 ms /    36 runs   (   17.95 ms per token,    55.72 tokens per second)
llama_print_timings:       total time =     667.28 ms /    36 

' There is 1 customer with an order count of more than 5.'

In [22]:
# langchain demo implementation

from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

generate_query = create_sql_query_chain(llm,db)
query = generate_query.invoke({"question": "How many orders are from France?"})
print(query)

execute_query = QuerySQLDataBaseTool(db=db)
result = execute_query.invoke(query)
print(result)

print("langchain's prompt: \n")
print(generate_query.get_prompts()[0].pretty_print())

Llama.generate: prefix-match hit

llama_print_timings:        load time =     589.95 ms
llama_print_timings:      sample time =       1.61 ms /    34 runs   (    0.05 ms per token, 21170.61 tokens per second)
llama_print_timings: prompt eval time =    1324.06 ms /  2845 tokens (    0.47 ms per token,  2148.70 tokens per second)
llama_print_timings:        eval time =     697.19 ms /    33 runs   (   21.13 ms per token,    47.33 tokens per second)
llama_print_timings:       total time =    2041.66 ms /  2878 tokens


SELECT COUNT(*) as total_orders
FROM orders
JOIN customers ON orders.customerNumber = customers.customerNumber
WHERE customers.country = 'France';
[(37,)]
You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get th

In [21]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

# create first chain, where the template prompt get passed to the llm, the output is extract as str with output parser
rephrase_answer = answer_prompt | llm | StrOutputParser()

chain = (
    RunnablePassthrough.assign(query=generate_query).assign(
        result=itemgetter("query") | execute_query
    )
    | rephrase_answer
)

#query = generate_query.invoke({"question": "What is the price of the '1969 Harley Davidson Ultimate Chopper'?"})
#print(query)
#result = execute_query.invoke(query)
#print(result)


chain.invoke({"question": "How many orders are from France?"})

Llama.generate: prefix-match hit

llama_print_timings:        load time =     589.95 ms
llama_print_timings:      sample time =       1.65 ms /    34 runs   (    0.05 ms per token, 20593.58 tokens per second)
llama_print_timings: prompt eval time =      73.94 ms /    16 tokens (    4.62 ms per token,   216.40 tokens per second)
llama_print_timings:        eval time =     698.71 ms /    33 runs   (   21.17 ms per token,    47.23 tokens per second)
llama_print_timings:       total time =     798.85 ms /    49 tokens
Llama.generate: prefix-match hit

llama_print_timings:        load time =     589.95 ms
llama_print_timings:      sample time =       0.53 ms /    11 runs   (    0.05 ms per token, 20912.55 tokens per second)
llama_print_timings: prompt eval time =      54.37 ms /    83 tokens (    0.66 ms per token,  1526.63 tokens per second)
llama_print_timings:        eval time =     165.67 ms /    10 runs   (   16.57 ms per token,    60.36 tokens per second)
llama_print_timings:       to

'\nThere are 37 orders from France.'

In [None]:
print(chain.get_prompts()[0].pretty_print())

In [None]:
print(chain.get_prompts()[1].pretty_print())

Return the only the names of ALL the tables that MIGHT be relevant to the user question in json format. The tables are:

Table Name: customers
Table Description: Stores customer information, including contact details, address, credit limit, and sales representative.

Table Name: employees
Table Description: Contains employee data, such as names, extensions, email addresses, job titles, and office codes.

Table Name: offices
Table Description: Holds office information, including city, phone number, address, state, country, and postal code.

Table Name: orderdetails
Table Description: Records order details, including product codes, quantities, prices, and line numbers.

Table Name: orders
Table Description: Stores order information, including order dates, required dates, shipped dates, status, and customer numbers.

Table Name: payments
Table Description: Tracks payment details, including check numbers, payment dates, and amounts.

Table Name: productlines
Table Description: Describes pr

In [37]:
select_table_prompt = ChatPromptTemplate.from_messages([
    ("system", table_details_prompt),
    ("user", "{input}")
])

select_llm = select_table_prompt | llm2 | StrOutputParser()
select_llm.invoke({"input": "How many cutomers with order count more than 5"})

Llama.generate: prefix-match hit

llama_print_timings:        load time =     218.59 ms
llama_print_timings:      sample time =       1.58 ms /    58 runs   (    0.03 ms per token, 36662.45 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (    -nan ms per token,     -nan tokens per second)
llama_print_timings:        eval time =     620.45 ms /    58 runs   (   10.70 ms per token,    93.48 tokens per second)
llama_print_timings:       total time =     645.05 ms /    58 tokens


" To answer the user's question, we need to consider tables that contain customer information and order details. The relevant tables are:\n\n1. customers\n2. orders\n\nThese tables contain the necessary information to determine the number of customers with more than 5 orders."

In [36]:
select_table = {"input": itemgetter("question")} | select_llm | get_tables
select_table.invoke({"question": "How many cutomers with order count more than 5"})

Llama.generate: prefix-match hit

llama_print_timings:        load time =     218.59 ms
llama_print_timings:      sample time =       1.65 ms /    63 runs   (    0.03 ms per token, 38228.16 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (    -nan ms per token,     -nan tokens per second)
llama_print_timings:        eval time =     661.68 ms /    63 runs   (   10.50 ms per token,    95.21 tokens per second)
llama_print_timings:       total time =     685.27 ms /    63 tokens


["To answer the user's question",
 'we need to consider tables that store customer information and order details. The relevant tables are:\n\n1. customers\n2. orderdetails\n3. orders\n\nThese tables contain the necessary information to determine the number of customers with more than 5 orders.']

In [35]:
chain = (
RunnablePassthrough.assign(table_names_to_use=select_table) |
RunnablePassthrough.assign(query=generate_query).assign(
    result=itemgetter("query") | execute_query
)
| rephrase_answer
)
chain.invoke({"question": "How many cutomers with order count more than 5"})

Llama.generate: prefix-match hit

llama_print_timings:        load time =     218.59 ms
llama_print_timings:      sample time =       1.53 ms /    61 runs   (    0.03 ms per token, 39791.26 tokens per second)
llama_print_timings: prompt eval time =      30.91 ms /    14 tokens (    2.21 ms per token,   452.91 tokens per second)
llama_print_timings:        eval time =     634.53 ms /    60 runs   (   10.58 ms per token,    94.56 tokens per second)
llama_print_timings:       total time =     683.96 ms /    74 tokens


ValueError: table_names {'we need to consider tables that store customer information and order details. The relevant tables are:\n\n1. customers\n2. orderdetails\n3. orders\n\nThese tables contain the necessary information to determine the number of customers with more than 5 orders.', 'To answer this user question'} not found in database

In [None]:
table_chain.get_prompts()[0]

In [None]:
for tool in tools:
    print(tool.name, tool.description)