# SQL Chain (RAG) in Langchain

## SQL

In [1]:
from langchain_community.llms import Ollama 

from langchain.memory import ConversationBufferMemory
from langchain.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableLambda, RunnablePassthrough, RunnableParallel
from langchain_openai import ChatOpenAI
from langchain import hub
from langchain_openai.embeddings import OpenAIEmbeddings
import os
import pandas as pd


In [9]:
obj = hub.pull("rlm/rag-prompt")
sql = hub.pull("rlm/text-to-sql")

In [16]:
print(obj.messages[0].prompt.template)

"You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.\nQuestion: {question} \nContext: {context} \nAnswer:"

In [19]:
print(sql.messages[0].prompt.template)

Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}.

Some examples of SQL queries that corrsespond to questions are:

{few_shot_examples}

Question: {input}


In [33]:
sys = """
ou are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today". Remember to add the schema name before the table name when generating query, such as from test.table1.

"""

In [36]:
tem1 = """
You are an assistant for giving text to sql task.
Based on the 
- table schema: {schema}
- sql documentation: {docs}
- sql similar examples: {examples}
- database specific information: {db_info}
Write a SQL query that would answer the user's question.
If you don't know the answer, just say that you don't know.

Question: {question} \n
SQL Query:
"""

In [37]:
tem2 = """
You are an assistant for giving text to sql task.
Based on the 
- table schema: {schema}
- sql documentation: {docs}
- sql similar examples: {examples}

Write a SQL query that would answer the user's question.
If you don't know the answer, just say that you don't know.

Question: {question} \n
SQL Query:
"""

In [20]:
model = Ollama(model = 'codellama')
llm = ChatOpenAI()

In [21]:
database_uri = f"redshift+psycopg2://{os.environ['redshift_user']}:{os.environ['redshift_pass']}@redshift-cluster-comp0087-demo.cvliubs5oipw.eu-west-2.redshift.amazonaws.com:5439/comp0087"
db = SQLDatabase.from_uri(database_uri=database_uri, schema = 'test')

def get_schema(_):
    return db.get_table_info()


def run_query(query):
    return db.run(query)


## Reteriver

1. docs 
2. example
3. additional
4. table specifi

In [2]:
embedd = OpenAIEmbeddings()

In [3]:
# docs
from langchain_community.document_loaders import PyPDFLoader

docs = PyPDFLoader('../../data/sql_sheet.pdf')
pages = docs.load_and_split()

In [4]:
from langchain_community.vectorstores import FAISS

faiss_index = FAISS.from_documents(pages, embedd)


In [5]:
docs2 = faiss_index.similarity_search("what is inner join?", k=2)
for doc in docs2:
    print(str(doc.metadata["page"]) + ":", doc.page_content[:300])

0: CITY COUNTRY
country_idid name name id
6
7 Vatican CityVatican City
5
10 11 Monaco Monaco 10
NATURAL JOIN used these columns to match rows:
city.id, city.name, country.id, country.name.
NATURAL JOIN is very rarely used in practice.
Try out the interactive SQL Basics course at LearnSQL.com, and check
0: SQL Bas ics Cheat Sheet
SQL, or Structured Query Language, is a language to talk to
databases. It allows you to select specific data and to build
complex reports. Today, SQL is a universal language of data. It is
used in practically all technologies that process data.
SAMPLE DATA
COUNTRY
idname popu


In [9]:
docs_reteriver = faiss_index.as_retriever(search_type="similarity", search_kwargs={"k": 1})

In [62]:
type(docs_reteriver)

langchain_core.vectorstores.VectorStoreRetriever

In [10]:
docs_reteriver.invoke('what is inner join?')

[Document(page_content='CITY COUNTRY\ncountry_idid name name id\n6\n7 Vatican CityVatican City\n5\n10 11 Monaco Monaco 10\nNATURAL JOIN used these columns to match rows:\ncity.id, city.name, country.id, country.name.\nNATURAL JOIN is very rarely used in practice.\nTry out the interactive SQL Basics course at LearnSQL.com, and check out our other SQL courses.LearnSQL.com is owned by Vertabelo SA\nvertabelo.com | CC BY-NC-ND Vertabelo SA1 1\n2 2\n1 1\n2 2\n4 NULL\n1 1\n2 2\nNULL 31 1\n2 2\n4 NULL\nNULL 3\n1 1\n1 2\n2 1\n2 2\n6San Marino San Marino6\n7 7\n9 Greece Greece 9', metadata={'source': '../../data/sql_sheet.pdf', 'page': 0})]

In [11]:
## example 

example = pd.read_csv('../../data/example_pair.csv')
example

Unnamed: 0,question,query
0,How many heads of the departments are older th...,SELECT count(*) FROM head WHERE age > 56
1,"List the name, born state and age of the heads...","SELECT name , born_state , age FROM head ORD..."
2,"List the creation year, name and budget of eac...","SELECT creation , name , budget_in_billions ..."
3,What are the maximum and minimum budget of the...,"SELECT max(budget_in_billions) , min(budget_i..."
4,What is the average number of employees of the...,SELECT avg(num_employees) FROM department WHER...
...,...,...
6995,What are all the company names that have a boo...,SELECT T1.company_name FROM culture_company AS...
6996,Show the movie titles and book titles for all ...,"SELECT T1.title , T3.book_title FROM movie AS..."
6997,What are the titles of movies and books corres...,"SELECT T1.title , T3.book_title FROM movie AS..."
6998,Show all company names with a movie directed i...,SELECT T2.company_name FROM movie AS T1 JOIN c...


In [13]:
from langchain_community.document_loaders.csv_loader import CSVLoader


loader = CSVLoader(file_path='../../data/example_pair.csv')
data = loader.load()

In [16]:
example_index = FAISS.from_documents(data, embedd)

In [23]:
example_reteriver = example_index.as_retriever()

## Template 1


In [25]:
prompt1 = ChatPromptTemplate.from_messages(
    [
        ("system", "Given an input question, convert it to a SQL query. No pre-amble."),
        MessagesPlaceholder(variable_name="history"),
        ("human", tem1),
    ]
)

memory = ConversationBufferMemory(return_messages=True)



In [25]:
## define reterview


# docs retreival
### code here ###
retriever_docs = 'xxx'

# example retreival
retriever_example = 'xxx'

# Additional reterival
retriever_addi = 'xxx'

def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

In [None]:
sql_chain = (
    RunnableParallel(
        {
            'schema': get_schema,
            'docs': retriever_docs | format_docs,
            'example': retriever_example | format_docs,
            'db_info': retriever_addi | format_docs,
            'question': RunnablePassthrough()
        }
    )
    | prompt1
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)


In [None]:
sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save

template_agent = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""  # noqa: E501
prompt_response = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Given an input question and SQL response, convert it to a natural "
            "language answer. No pre-amble.",
        ),
        ("human", template_agent),
    ]
)

class InputType(BaseModel):
    question: str

In [None]:
chain = (
    RunnablePassthrough.assign(query=sql_response_memory).with_types(
        input_type=InputType
    )
    | RunnablePassthrough.assign(
        schema=get_schema,
        response=lambda x: db.run(x["query"]),
    )
    | prompt_response
    | llm
)

## Template 2

In [34]:
prompt2 = ChatPromptTemplate.from_messages(
    [
        ("system", sys),
        ("human", tem2),
    ]
)


In [59]:
sql_chain =  {
            'docs': docs_reteriver | format_docs,
            'examples': example_reteriver | format_docs,
            'question': RunnablePassthrough()
        }| RunnablePassthrough.assign(schema = get_schema) | prompt2 | llm.bind(stop=["\nSQLResult:"]) | StrOutputParser()


In [57]:
sql_chain =  {
            'docs': docs_reteriver | format_docs,
            'examples': example_reteriver | format_docs,
            'question': RunnablePassthrough()
        }| RunnablePassthrough.assign(schema = get_schema) | prompt2 | llm | StrOutputParser()


In [60]:
sql_chain.invoke('How many distinct teams in the football league')

'SELECT COUNT(DISTINCT hometeam) AS num_distinct_teams\nFROM test.football\nUNION\nSELECT COUNT(DISTINCT awayteam)\nFROM test.football;'

In [51]:
sql_chain =  {
            # 'docs': docs_reteriver,
            # 'example': example_reteriver,
            'question': RunnablePassthrough()
        }| RunnablePassthrough.assign(schema = get_schema) 

In [52]:
sql_chain.invoke({'question': 'How many distinct teams in the football league'})

{'question': {'question': 'How many distinct teams in the football league'},
 'schema': '\nCREATE TABLE test.football (\n\tseason VARCHAR(256), \n\tleague VARCHAR(256), \n\tdiv VARCHAR(256), \n\tdate DATE, \n\thometeam VARCHAR(256), \n\tawayteam VARCHAR(256), \n\tfthg REAL, \n\tftag REAL, \n\tftr VARCHAR(256), \n\ththg REAL, \n\thtag REAL, \n\thtr VARCHAR(256), \n\treferee VARCHAR(256), \n\ths REAL, \n\t"as" REAL, \n\thst REAL, \n\tast REAL, \n\thf REAL, \n\taf REAL, \n\thc REAL, \n\tac REAL, \n\thy REAL, \n\tay REAL, \n\thr REAL, \n\tar REAL, \n\tmatchname VARCHAR(256)\n)\n\n/*\n3 rows from football table:\nseason\tleague\tdiv\tdate\thometeam\tawayteam\tfthg\tftag\tftr\ththg\thtag\thtr\treferee\ths\tas\thst\tast\thf\taf\thc\tac\thy\tay\thr\tar\tmatchname\n2000-2001\tPremier League\tE0\t2000-08-19\tCharlton\tManchester City\t4.0\t0.0\tH\t2.0\t0.0\tH\tRob Harris\t17.0\t8.0\t14.0\t4.0\t13.0\t12.0\t6.0\t6.0\t1.0\t2.0\t0.0\t0.0\tCharlton VS Manchester City\n2000-2001\tPremier League\tE0\t2