In [None]:
# run this cell if langsmith is used to trace the chain
import os
from helper import get_api_key

# Configure LangSmith tracing through environment variables (optional cell).
# Send traces to the hosted LangSmith endpoint.
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
# The key is fetched through the local `helper` module instead of being written here.
# NOTE(review): what the argument `1` selects is defined in helper.get_api_key, which is not visible in this notebook — confirm.
os.environ["LANGCHAIN_API_KEY"] = get_api_key(1)
# Project name under which runs are recorded.
os.environ['LANGCHAIN_PROJECT'] = 'default'
# Environment values must be strings, hence "true" rather than True.
os.environ["LANGCHAIN_TRACING_V2"] = "true"

In [1]:
from typing import List

from langchain_community.llms.llamacpp import LlamaCpp
from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings

class LlamaCppEmbeddings_(LlamaCppEmbeddings):
    """LlamaCppEmbeddings variant that keeps only element 0 of what ``client.embed`` returns.

    NOTE(review): presumably the installed llama-cpp-python returns a nested
    list per text, which is why the first element is selected — confirm.
    """

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed several documents with the Llama model.

        Args:
            texts: Documents to embed, in order.

        Returns:
            One embedding (a list of floats) per input text, in the same order.
        """
        # Run every embedding call first, then normalise the numbers to float.
        first_vectors = []
        for document in texts:
            first_vectors.append(self.client.embed(document)[0])
        return [[float(component) for component in vector] for vector in first_vectors]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string with the Llama model.

        Args:
            text: The query to embed.

        Returns:
            The embedding as a list of floats.
        """
        first_vector = self.client.embed(text)[0]
        return [float(component) for component in first_vector]

In [2]:
# using local installed llama_cpp_python
import multiprocessing

from langchain_community.chat_models import ChatLlamaCpp

local_model = "./models/Mistral-Nemo-Instruct-2407.Q5_K_S.gguf"

# Use every logical core except one for inference.
worker_threads = multiprocessing.cpu_count() - 1

# Chat model loaded in-process from the GGUF file above.
llm = ChatLlamaCpp(
    model_path=local_model,
    n_ctx=5000,
    n_batch=1000,  # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.
    n_gpu_layers=100,
    n_threads=worker_threads,
    max_tokens=768,
    temperature=0.3,
    top_p=0.8,
    repeat_penalty=1,
)

# The same GGUF file is opened a second time to produce embeddings.
# NOTE(review): no n_gpu_layers / n_ctx is passed here, so this copy runs with
# the wrapper's defaults — confirm that is intended.
embeddings = LlamaCppEmbeddings_(model_path=local_model)


llama_model_loader: loaded meta data with 32 key-value pairs and 363 tensors from ./models/Mistral-Nemo-Instruct-2407.Q5_K_S.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Models
llama_model_loader: - kv   3:                         general.size_label str              = 12B
llama_model_loader: - kv   4:                            general.license str              = apache-2.0
llama_model_loader: - kv   5:                          general.languages arr[str,9]       = ["en", "fr", "de", "es", "it", "pt", ...
llama_model_loader: - kv   6:                          llama.block_count u32              = 40
llama_model_loader: - k

In [None]:
# using locally served llama_cpp with openai wrapper
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings

# Both clients below talk to the same locally served, OpenAI-compatible endpoint,
# so the address and the placeholder credential are defined once and shared.
llama_server_url = "http://llama_server:8080/"  # "http://<Your api-server IP>:port"
# The OpenAI client libraries refuse to start without an API key, so a dummy
# value is supplied to both clients.
llama_server_api_key = "no_key"

# LangChain chat model pointed at the local server. The key is passed explicitly
# so this cell does not depend on OPENAI_API_KEY being set in the environment.
llm = ChatOpenAI(
    openai_api_base=llama_server_url,
    openai_api_key=llama_server_api_key,
    temperature=0.3,
)

import openai
# Plain OpenAI SDK client against the same server.
client = openai.OpenAI(
    base_url=llama_server_url,
    api_key=llama_server_api_key,
)


In [3]:
# get the SQL utilities from langchain
from langchain_community.utilities import SQLDatabase


# Connection string for the classicmodels schema on the MySQL host.
# NOTE(review): user name and password are embedded in the URI — move them to
# environment variables before sharing this notebook.
classicmodels_uri = "mysql://user:password@mysql_db:3306/classicmodels"
db = SQLDatabase.from_uri(classicmodels_uri)

# Sanity checks: SQL dialect, visible tables, then a five-row preview of products.
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM products LIMIT 5;")

mysql
['customers', 'employees', 'offices', 'orderdetails', 'orders', 'payments', 'productlines', 'products']


"[('S10_1678', '1969 Harley Davidson Ultimate Chopper', 'Motorcycles', '1:10', 'Min Lin Diecast', 'This replica features working kickstand, front suspension, gear-shift lever, footbrake lever, drive chain, wheels and steering. All parts are particularly delicate due to their precise scale and require special care and attention.', 7933, Decimal('48.81'), Decimal('95.70')), ('S10_1949', '1952 Alpine Renault 1300', 'Classic Cars', '1:10', 'Classic Metal Creations', 'Turnable front wheels; steering function; detailed interior; detailed engine; opening hood; opening trunk; opening doors; and detailed chassis.', 7305, Decimal('98.58'), Decimal('214.30')), ('S10_2016', '1996 Moto Guzzi 1100i', 'Motorcycles', '1:10', 'Highway 66 Mini Classics', 'Official Moto Guzzi logos and insignias, saddle bags located on side of motorcycle, detailed engine, working steering, working suspension, two leather seats, luggage rack, dual exhaust pipes, small saddle bag located on handle bars, two-tone paint with

In [None]:
# Hand-written reference query: buyPrice of one product, matched by exact productName.
db.run("SELECT buyPrice from products where productName = '1969 Harley Davidson Ultimate Chopper';")

In [None]:
# Hand-written reference query: employees whose jobTitle contains "Marketing".
db.run("SELECT firstName, lastName, jobTitle FROM employees WHERE jobTitle LIKE '%Marketing%';")

In [4]:
from langchain.chains import create_sql_query_chain

# Chain that asks the model to write SQL for `db` from a natural-language question.
generate_query = create_sql_query_chain(llm, db)

# Run it once on a sample question and show the SQL text that comes back.
harley_price_question = {"question": "What is the price of '1969 Harley Davidson Ultimate Chopper'?"}
query = generate_query.invoke(harley_price_question)
print(query)



llama_print_timings:        load time =     861.23 ms
llama_print_timings:      sample time =       0.90 ms /    14 runs   (    0.06 ms per token, 15555.56 tokens per second)
llama_print_timings: prompt eval time =    2295.43 ms /  2939 tokens (    0.78 ms per token,  1280.37 tokens per second)
llama_print_timings:        eval time =     315.47 ms /    13 runs   (   24.27 ms per token,    41.21 tokens per second)
llama_print_timings:       total time =    2632.12 ms /  2952 tokens


SELECT price
FROM products
WHERE productCode =01 =


In [6]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
# Tool that runs a SQL string against `db`.
execute_query = QuerySQLDataBaseTool(db=db)
# NOTE(review): `query` is the raw text returned by generate_query and is executed without any validation.
execute_query.invoke(query)

''

In [6]:
# Pipe the generated SQL text straight into the execution tool.
test_chain = generate_query | execute_query
# `result` holds whatever the execution tool returns for the generated SQL.
result = test_chain.invoke({"question": "What is the price of the '1969 Harley Davidson Ultimate Chopper'?"})

Llama.generate: prefix-match hit

llama_print_timings:        load time =     533.01 ms
llama_print_timings:      sample time =       2.18 ms /    73 runs   (    0.03 ms per token, 33501.61 tokens per second)
llama_print_timings: prompt eval time =      49.95 ms /    30 tokens (    1.66 ms per token,   600.64 tokens per second)
llama_print_timings:        eval time =    1351.18 ms /    72 runs   (   18.77 ms per token,    53.29 tokens per second)
llama_print_timings:       total time =    1455.42 ms /   102 tokens
Error closing cursor
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/sqlalchemy/engine/base.py", line 2213, in _safe_close_cursor
    cursor.close()
  File "/usr/local/lib/python3.10/dist-packages/MySQLdb/cursors.py", line 83, in close
    while self.nextset():
  File "/usr/local/lib/python3.10/dist-packages/MySQLdb/cursors.py", line 137, in nextset
    nr = db.next_result()
MySQLdb.ProgrammingError: (1064, "You have an error in your SQL syn

In [None]:
# pretty_print() writes the prompt to stdout itself and returns None, so it is
# called directly; wrapping it in print() would add a stray "None" line.
test_chain.get_prompts()[0].pretty_print()

In [8]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Template that asks the model to turn the raw SQL result into a sentence.
answer_template = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
answer_prompt = PromptTemplate.from_template(answer_template)

# Final stage: fill the template, call the model, keep only the reply text.
rephrase_answer = answer_prompt | llm | StrOutputParser()

# Earlier stages add keys to the input dict one at a time:
#   1. "query"  - SQL produced by generate_query from the incoming question
#   2. "result" - output of execute_query for that SQL
with_query = RunnablePassthrough.assign(query=generate_query)
with_result = with_query.assign(result=itemgetter("query") | execute_query)
chain = with_result | rephrase_answer

# To debug a single stage, call generate_query.invoke(...) and then
# execute_query.invoke(query) by hand and print the intermediate values.

price_question = {"question": "What is the price of the '1969 Harley Davidson Ultimate Chopper'?"}
chain.invoke(price_question)

Llama.generate: prefix-match hit

llama_print_timings:        load time =     533.01 ms
llama_print_timings:      sample time =       2.11 ms /    73 runs   (    0.03 ms per token, 34548.04 tokens per second)
llama_print_timings: prompt eval time =    1779.20 ms /  3266 tokens (    0.54 ms per token,  1835.66 tokens per second)
llama_print_timings:        eval time =    1343.44 ms /    72 runs   (   18.66 ms per token,    53.59 tokens per second)
llama_print_timings:       total time =    3169.33 ms /  3338 tokens
Error closing cursor
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/sqlalchemy/engine/base.py", line 2213, in _safe_close_cursor
    cursor.close()
  File "/usr/local/lib/python3.10/dist-packages/MySQLdb/cursors.py", line 83, in close
    while self.nextset():
  File "/usr/local/lib/python3.10/dist-packages/MySQLdb/cursors.py", line 137, in nextset
    nr = db.next_result()
MySQLdb.ProgrammingError: (1064, "You have an error in your SQL syn

"The price of the '1969 Harley Davidson Ultimate Chopper' is 48.81."

In [None]:
# First prompt used inside `chain`. pretty_print() prints by itself and returns
# None, so it is not wrapped in print() (that would emit an extra "None").
chain.get_prompts()[0].pretty_print()

In [None]:
# Second prompt used inside `chain`. pretty_print() prints by itself and returns
# None, so it is not wrapped in print() (that would emit an extra "None").
chain.get_prompts()[1].pretty_print()

In [10]:
# get a list of example prompts from sql_examples.json
# read json file
import json
# Decode the file as UTF-8 explicitly: without `encoding`, open() falls back to
# the platform's default encoding, which can mis-read non-ASCII text in the examples.
with open("sql_examples.json", "r", encoding="utf-8") as read_file:
    # `examples` is passed to the example selector in a later cell.
    examples = json.load(read_file)

In [11]:
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector

# create example selector using vector search
# Clear Chroma's default collection first — presumably so that re-running this
# cell does not index the same examples twice.
vectorstore = Chroma()
vectorstore.delete_collection()

# from_examples() takes the vector-store *class* and builds its own store from
# the examples, so the class is passed rather than the instance created above.
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    embeddings,
    Chroma,
    k=2,  # number of examples returned per lookup
    input_keys=["input"],  # only the "input" text is embedded for similarity search
)
# Quick check: which stored examples are closest to this question?
example_selector.select_examples({"input": "how many employees we have?"})


llama_print_timings:        load time =     668.50 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =     661.29 ms /     8 tokens (   82.66 ms per token,    12.10 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =     668.82 ms /     9 tokens

llama_print_timings:        load time =     668.50 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =     733.25 ms /    12 tokens (   61.10 ms per token,    16.37 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =     739.33 ms /    13 tokens

llama_print_timings:     

[{'input': 'Get the details of the employee with employee number 1002.',
  'query': 'SELECT * FROM employees WHERE employeeNumber = 1002;'},
 {'input': 'Get the contact details of customers who have a credit limit greater than 100,000.',
  'query': 'SELECT customerName, contactLastName, contactFirstName, phone FROM customers WHERE creditLimit > 100000;'}]

In [12]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate

# Each stored example is rendered as a human question followed by the AI's SQL.
example_messages = [
    ("human", "{input}\nSQLQuery:"),
    ("ai", "{query}"),
]
example_prompt = ChatPromptTemplate.from_messages(example_messages)

# The selector decides which examples are shown for each incoming question.
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    input_variables=["input", "top_k"],
)

# Render the examples chosen for a sample question.
few_shot_preview = few_shot_prompt.format(input="How many products are there?")
print(few_shot_preview)



llama_print_timings:        load time =     668.50 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =     909.35 ms /     7 tokens (  129.91 ms per token,     7.70 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =     915.05 ms /     8 tokens


Human: Get the details of the employee with employee number 1002.
SQLQuery:
AI: SELECT * FROM employees WHERE employeeNumber = 1002;
Human: Get the contact details of customers who have a credit limit greater than 100,000.
SQLQuery:
AI: SELECT customerName, contactLastName, contactFirstName, phone FROM customers WHERE creditLimit > 100000;


In [None]:
# Full query-writing prompt: system instructions with the table info, then the
# selected few-shot examples, then the user's question.
# The system text previously ended its first paragraph with the dangling,
# misspelled fragment "Unless otherwise specificed."; that fragment is removed
# because it carried no instruction.
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
        few_shot_prompt,
        ("human", "{input}"),
    ]
)
# Render once with placeholder table info to inspect the assembled prompt.
print(final_prompt.format(input="How many products are there?",table_info="some table info"))

In [18]:
# System turn carries the table-selection instructions; the user turn carries the question.
# NOTE(review): table_details_prompt is assigned in a cell that appears further
# down in this file, so this cell fails on a fresh top-to-bottom run.
prompt = ChatPromptTemplate.from_messages(
    [("system", table_details_prompt), ("user", "{input}")]
)

In [19]:
# Send the table-selection prompt straight to the model, with no output parser.
testing = prompt | llm
# The bare call is the cell's last expression, so the returned message is displayed.
testing.invoke({"input": "give me details of customer and what they have orders"})

Llama.generate: prefix-match hit

llama_print_timings:        load time =     533.01 ms
llama_print_timings:      sample time =      11.38 ms /   389 runs   (    0.03 ms per token, 34191.79 tokens per second)
llama_print_timings: prompt eval time =      43.53 ms /    10 tokens (    4.35 ms per token,   229.75 tokens per second)
llama_print_timings:        eval time =    6083.83 ms /   388 runs   (   15.68 ms per token,    63.78 tokens per second)
llama_print_timings:       total time =    6421.09 ms /   398 tokens


AIMessage(content="To provide details of a customer and their orders, we might need to query the following tables:\n\n1. customers: This table contains customer information, such as contact details, address, credit limit, and sales representative.\n2. orders: This table stores order information, including order dates, required dates, shipped dates, status, and customer numbers.\n3. orderdetails: This table records order details, including product codes, quantities, prices, and line numbers.\n\nTo get the required information, you can use a SQL query like this:\n\n```sql\nSELECT \n    customers.customer_name, \n    customers.contact_name, \n    customers.address_line1, \n    customers.city, \n    customers.postal_code, \n    customers.country, \n    customers.credit_limit, \n    customers.sales_representative, \n    orders.order_date, \n    orders.required_date, \n    orders.shipped_date, \n    orders.status, \n    orderdetails.product_code, \n    orderdetails.quantity, \n    orderdetai

In [13]:
# Read a CSV file of table names and descriptions into a polars DataFrame, then return every row as text in the format "Table Name: xxx\nTable Description: xxx".
import polars as pl
def get_table_info(file_path: str) -> str:
    """Render the table catalogue stored in a CSV file as prompt-ready text.

    Args:
        file_path: Path of a CSV file read with ``pl.read_csv``; it must have
            the columns ``Table Name`` and ``Description``.

    Returns:
        One ``Table Name: ...`` / ``Table Description: ...`` pair per row, each
        pair followed by a blank line. An empty string when there are no rows.
    """
    catalog = pl.read_csv(file_path)
    # Columns are looked up row by row, so a file with zero rows never touches them.
    return "".join(
        f"Table Name: {catalog['Table Name'][row]}\nTable Description: {catalog['Description'][row]}\n\n"
        for row in range(len(catalog))
    )

In [14]:
# Build the table catalogue text once; later cells interpolate it into prompts.
table_info = get_table_info("tables_description.csv")
print(table_info)

Table Name: customers
Table Description: Stores customer information, including contact details, address, credit limit, and sales representative.

Table Name: employees
Table Description: Contains employee data, such as names, extensions, email addresses, job titles, and office codes.

Table Name: offices
Table Description: Holds office information, including city, phone number, address, state, country, and postal code.

Table Name: orderdetails
Table Description: Records order details, including product codes, quantities, prices, and line numbers.

Table Name: orders
Table Description: Stores order information, including order dates, required dates, shipped dates, status, and customer numbers.

Table Name: payments
Table Description: Tracks payment details, including check numbers, payment dates, and amounts.

Table Name: productlines
Table Description: Describes product lines, with text and HTML descriptions, and optional images.

Table Name: products
Table Description: Contains prod

In [15]:
from langchain_core.pydantic_v1 import BaseModel, Field
# Schema with a single `name` field; a later cell passes it to create_extraction_chain_pydantic.
# NOTE(review): the docstring and the Field description are presumably forwarded
# to the model as part of the schema, so treat them as prompt text — confirm.
class Table(BaseModel):
    """Table in SQL database."""

    # Name of one table.
    name: str = Field(description="Name of table in SQL database.")

In [16]:
# Instructions asking the model to list the tables relevant to a question.
# The text refers to the "Table Description" label that get_table_info() emits,
# so the label is spelled exactly the same here; the misspelled "PORTENTIALLY",
# the doubled article in "Return the only the" and "user questions" are fixed too.
# The trailing backslash joins the first line of the literal with the blank line
# that follows it; `table_info` is interpolated when this cell runs.
table_details_prompt = f"""Return only the names of ALL the SQL tables that MIGHT be relevant to the user question using the information in Table Description. The tables are: \

{table_info}
Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed. Only output the POTENTIALLY RELEVANT Table Names and nothing else. Here is the user question:"""

In [7]:
# Show the rendered instructions; the table list was substituted in when the f-string was evaluated.
print(table_details_prompt)

Return the only the names of ALL the tables that MIGHT be relevant to the user question in json format. The tables are:

Table Name: customers
Table Description: Stores customer information, including contact details, address, credit limit, and sales representative.

Table Name: employees
Table Description: Contains employee data, such as names, extensions, email addresses, job titles, and office codes.

Table Name: offices
Table Description: Holds office information, including city, phone number, address, state, country, and postal code.

Table Name: orderdetails
Table Description: Records order details, including product codes, quantities, prices, and line numbers.

Table Name: orders
Table Description: Stores order information, including order dates, required dates, shipped dates, status, and customer numbers.

Table Name: payments
Table Description: Tracks payment details, including check numbers, payment dates, and amounts.

Table Name: productlines
Table Description: Describes pr

In [17]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic
# System message for table extraction; `table_info` is interpolated right here.
table_details_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{table_info}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

# Chain that should return `Table` objects naming the relevant tables.
table_chain = create_extraction_chain_pydantic(
    Table,
    llm,
    system_message=table_details_prompt,
)

# Try it on a sample question; the bare name on the last line displays the result.
customer_orders_question = {"input": "give me details of customer and their orders"}
tables = table_chain.invoke(customer_orders_question)
tables

  warn_deprecated(
Llama.generate: prefix-match hit

llama_print_timings:        load time =     533.01 ms
llama_print_timings:      sample time =       8.64 ms /   292 runs   (    0.03 ms per token, 33788.47 tokens per second)
llama_print_timings: prompt eval time =     219.15 ms /   315 tokens (    0.70 ms per token,  1437.36 tokens per second)
llama_print_timings:        eval time =    4486.26 ms /   291 runs   (   15.42 ms per token,    64.86 tokens per second)
llama_print_timings:       total time =    4907.80 ms /   606 tokens


[]

In [None]:
# Inspect the first prompt object used inside table_chain.
table_chain.get_prompts()[0]

In [None]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

# Build the standard SQL tools for this database / model pair.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()

# Pick the two tools needed here by name (first match wins).
list_tables_tool = next(filter(lambda tool: tool.name == "sql_db_list_tables", tools))
get_schema_tool = next(filter(lambda tool: tool.name == "sql_db_schema", tools))

# The listing tool returns a single string; show it, then split it on ", " to
# get the individual table names.
table_listing = list_tables_tool.invoke("")
print(table_listing)
tables = table_listing.split(", ")

# Print the schema text for each table in turn.
for table in tables:
    print(get_schema_tool.invoke(table))




In [None]:
# List every tool the toolkit produced, one per line: its name followed by its description.
for tool in tools:
    print(tool.name, tool.description)