# SQL-based chains
Implementation inspired by
> https://python.langchain.com/docs/use_cases/qa_structured/sql

which adopts best practices from 
> https://blog.langchain.dev/llms-and-sql/

In [1]:
#### Load db wrapper and llm

from langchain.utilities import SQLDatabase
from langchain.llms import GPT4All
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler


db_uri = "sqlite:///db/retention-sqlite.db"
model_path = "../privateGPT/models/ggml-model-gpt4all-falcon-q4_0.bin" 
VERBOSE = True


db = SQLDatabase.from_uri(
    db_uri,
    include_tables=['fact_retention_model'],
    sample_rows_in_table_info=2)

llm = GPT4All(
    model=model_path,
    max_tokens=1000,
    backend='gptj',
    n_batch=8,
    callbacks=[StreamingStdOutCallbackHandler()],
    verbose=VERBOSE,
)

Found model file at  ../privateGPT/models/ggml-model-gpt4all-falcon-q4_0.bin
falcon_model_load: loading model from '../privateGPT/models/ggml-model-gpt4all-falcon-q4_0.bin' - please wait ...
falcon_model_load: n_vocab   = 65024
falcon_model_load: n_embd    = 4544
falcon_model_load: n_head    = 71
falcon_model_load: n_head_kv = 1
falcon_model_load: n_layer   = 32
falcon_model_load: ftype     = 2
falcon_model_load: qntvr     = 0
falcon_model_load: ggml ctx size = 3872.64 MB
falcon_model_load: memory_size =    32.00 MB, n_mem = 65536
falcon_model_load: ..

objc[14036]: Class GGMLMetalClass is implemented in both /opt/miniconda3/envs/gnn/lib/python3.11/site-packages/gpt4all/llmodel_DO_NOT_MODIFY/build/libreplit-mainline-metal.dylib (0x14e814228) and /opt/miniconda3/envs/gnn/lib/python3.11/site-packages/gpt4all/llmodel_DO_NOT_MODIFY/build/libllamamodel-mainline-metal.dylib (0x14e6dc228). One of the two will be used. Which one is undefined.


...................... done
falcon_model_load: model size =  3872.59 MB / num tensors = 196


In [2]:
#### Load retriever vectorstore for similar questions and queries

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from chromadb.config import Settings
import chromadb


persist_directory = "vectorstore/"
# A Chinese-compatible embedding model
# NOTE: If you hit `No transformers found`, follow the instructions from
# https://huggingface.co/uer/sbert-base-chinese-nli/commit/827d1b828afae1e0fde9f26a590ff0e1eaded589#d2h-223453
embedding_model = "uer/sbert-base-chinese-nli" 
top_k = 2
CHROMA_SETTINGS = Settings(
    persist_directory=persist_directory,
    anonymized_telemetry=False
)


embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
chroma_client = chromadb.PersistentClient(settings=CHROMA_SETTINGS, path=persist_directory)
vectordb = Chroma(
    persist_directory=persist_directory, 
    embedding_function=embeddings, 
    client_settings=CHROMA_SETTINGS, 
    client=chroma_client
)

retriever = vectordb.as_retriever(search_kwargs={"k": top_k})
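The retriever above assumes that `vectorstore/` already holds question/query pairs; this notebook does not show how the store was populated. A minimal sketch of one way to seed it, assuming each pair is stored as a single text in the `Question: ...` / `Query: ...` layout that appears in the retrieved examples further down. `EXAMPLES` and `format_example` are hypothetical names, and the `add_texts` call is left commented out because it needs the Chroma store from the cell above.

```python
from typing import List, Tuple

# One question/query pair copied from the examples retrieved later in this notebook.
# The question asks, roughly, for the holdings of public funds at the end of last month.
EXAMPLES: List[Tuple[str, str]] = [
    (
        "上月末公募基金的保有规模是多少",
        "select sum(asset) from fact_retention_model where raisetype='公募基金' "
        "and cdate=(DATE_FORMAT(CURRENT_DATE(), '%Y-%m-01')-INTERVAL 1 DAY)",
    ),
]


def format_example(question: str, query: str) -> str:
    """Render one pair as a two-line text: question first, query second (assumed layout)."""
    return f"Question: {question}\nQuery: {query}"


texts = [format_example(question, query) for question, query in EXAMPLES]

# With the `vectordb` object from the cell above, the texts could then be stored with:
# vectordb.add_texts(texts)
```

Note that the stored query uses MySQL date syntax (`DATE_FORMAT`, `INTERVAL`), which SQLite does not support, so a model that imitates these examples may produce SQL that fails against the SQLite database loaded in the first cell.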

In [3]:
#### Embeddings prompt template

from typing import Any, List
from langchain.prompts import PromptTemplate
from langchain.schema.prompt_template import BasePromptTemplate
from langchain.schema import format_document


_specific_document_prompt_template = PromptTemplate(input_variables=["page_content"], template="`{page_content}`")
_specific_prompt_template = PromptTemplate(
    input_variables=["context"],
    template=
"""Some examples of SQL queries that correspond to questions are listed below, and are enclosed in ``:

{context}
"""
)


def specific_stuff_combine_docs(
    question: str,
    document_prompt: BasePromptTemplate = _specific_document_prompt_template,
    document_separator: str = "\n\n",
    prompt: BasePromptTemplate = _specific_prompt_template,
    document_variable_name: str = "context",
    input_variables: List[str] = ["context"],
    verbose: bool = False,
    **kwargs: Any,
    ) -> str:
    """
    Retrieve relevant docs from the vectorstore, then
    concatenate them into a prompt string without calling an LLM.
    """
    
    docs = retriever.get_relevant_documents(question)
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    # NOTE: inputs are not validated here, unlike StuffDocumentsChain:78 get_default_document_variable_name
    inputs = {
        k: v
        for k, v in kwargs.items()
        if k in input_variables
    }
    inputs[document_variable_name] = document_separator.join(doc_strings)
    formatted = prompt.format(**inputs)
    if verbose:
        print(formatted)
    return formatted
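To make the shape of the returned string concrete, here is a dependency-free sketch of the same "stuff" step: each retrieved text is wrapped in backticks, the pieces are joined with a blank line, and the result is substituted into the header template. Plain strings stand in for LangChain `Document` objects, and `stuff_examples` and `HEADER_TEMPLATE` are hypothetical names used only for illustration.

```python
from typing import List

HEADER_TEMPLATE = (
    "Some examples of SQL queries that correspond to questions are listed below, "
    "and are enclosed in ``:\n\n{context}\n"
)


def stuff_examples(page_contents: List[str], separator: str = "\n\n") -> str:
    """Wrap each text in backticks, join them, and fill the header template."""
    wrapped = [f"`{text}`" for text in page_contents]
    return HEADER_TEMPLATE.format(context=separator.join(wrapped))


demo = stuff_examples(
    ["Question: q1\nQuery: select 1", "Question: q2\nQuery: select 2"]
)
print(demo)
```

The printed text has the same layout as the few-shot block that shows up in the verbose output of the next section.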

---
## RunnableSequence
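First approach, based on *create_sql_query_chain*.
> Be careful when using this approach, as it is susceptible to SQL injection.

To mitigate this, keep query generation (invoke) separate from query execution (run), so the generated SQL can be reviewed before it reaches the database.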
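One way to keep the two steps apart is to let the chain produce only SQL text and to vet that text before running it. The sketch below assumes a read-only use case in which a single `SELECT` statement is the only acceptable output. `is_safe_select` and `run_read_only` are hypothetical helpers, and the keyword filter is a coarse, conservative check rather than a complete defence; opening the SQLite file in read-only mode is the stronger safeguard.

```python
import re
import sqlite3

# Statements that should never appear in a read-only analytics query (assumed policy).
_FORBIDDEN = re.compile(
    r"\b(insert|update|delete|drop|alter|create|attach|detach|pragma)\b",
    re.IGNORECASE,
)


def is_safe_select(sql: str) -> bool:
    """Return True only for a single statement that starts with SELECT."""
    statement = sql.strip().rstrip(";").strip()
    if ";" in statement:  # more than one statement (or a ';' inside a literal)
        return False
    if not statement.lower().startswith("select"):
        return False
    return _FORBIDDEN.search(statement) is None


def run_read_only(db_path: str, sql: str):
    """Execute a vetted query on a SQLite connection opened in read-only mode."""
    if not is_safe_select(sql):
        raise ValueError("refusing to run non-SELECT or multi-statement SQL")
    conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
    try:
        return conn.execute(sql).fetchall()
    finally:
        conn.close()
```

With the database from the first cell, the generated statement could then be executed as `run_read_only("db/retention-sqlite.db", sql)`, where `sql` is the text returned by the chain.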

In [6]:
#### RunnableSequence, no memory
import time
from langchain.chains import create_sql_query_chain
from langchain.schema.output_parser import StrOutputParser


# NOTE:
# This chain looks up the table info automatically,
# which we could improve with annotations from an agent.
TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}.

{few_shot_examples}.

Question: {input}"""

CUSTOM_PROMPT = PromptTemplate(
    input_variables=["dialect", "top_k", "table_info", "few_shot_examples", "input"], template=TEMPLATE
)

chain = create_sql_query_chain(llm=llm, db=db, prompt=CUSTOM_PROMPT)

# Create chain with LangChain Expression Language
while True:
    question = input("\nEnter a question: ")
    if question == "exit":
        break
    if question.strip() == "":
        continue
    
    start = time.time()
    # NOTE: rlm example
    # Refer to https://smith.langchain.com/hub/rlm/text-to-sql
    inputs = {
        "table_info": lambda _: db.get_table_info(),
        "input": lambda x : x["question"] + "\nSQLQuery: ",
        "few_shot_examples": lambda x: specific_stuff_combine_docs(x["question"], verbose=True),
        "dialect":  lambda _: db.dialect,
        "top_k": lambda _ : 5,
    }
    runnableMap = (
        inputs
        | CUSTOM_PROMPT
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
    )
    # sql_response = chain.invoke(inputs)
    sql_response = runnableMap.invoke({"question": question})
    end = time.time()


    # Print the result
    print("\n\n> Question:")
    print(question)
    print(f"\n> Answer (took {round(end - start, 2)} s.):")
    print(sql_response)



Some examples of SQL queries that correspond to questions are listed below, and are enclosed in ``:

`Question: 上月末公募基金的保有规模是多少
Query: select sum(asset) from fact_retention_model where raisetype='公募基金' and cdate=(DATE_FORMAT(CURRENT_DATE(), '%Y-%m-01')-INTERVAL 1 DAY)`

`Question: 上月末权益类基金产品的保有规模,保有客户数是多少
Query: select sum(asset),count(distinct reten_syscustomerid) from fact_retention_model where product_type_bi='权益类' and cdate=(DATE_FORMAT(CURRENT_DATE(), '%Y-%m-01')-INTERVAL 1 DAY)`

`SELECT SUM(asset) FROM fact_retention_model WHERE raisetype='公募基金' AND cdate=(DATE_FORMAT(CURRENT_DATE(), '%Y-%m-01')-INTERVAL 1 DAY)`

Question: 上月末槟募基金产品的保有规模,保有客户数是多少
SQLQuery: `SELECT SUM(asset), COUNT(DISTINCT retin_syscustomerid) FROM fact_retention_model WHERE product_type_bi='槟募基金' AND cdate=(DATE_FORMAT(CURRENT_DATE(), '%Y-%m-01')-INTERVAL 1 DAY)`

Question: 上月末槟募基金产品的份项槟募基金
SQLQuery: `SELECT SUM(asset) FROM fact_retention_model WHERE raiset

> Question:
上月末公募基金的保有规模是多少

> Answer (took 165.31 s
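The output above shows the local model wrapping its SQL in backticks and then inventing further `Question:` / `SQLQuery:` pairs, since it never emits the `\nSQLResult:` stop sequence. A minimal post-processing sketch, assuming the first statement in the completion is the one that answers the user; `extract_first_sql` is a hypothetical helper and the sample completion is shortened from the real output.

```python
def extract_first_sql(response: str) -> str:
    """Return the first SQL statement from a raw completion.

    Cuts the text at the first follow-on "Question:" marker, drops an optional
    "SQLQuery:" prefix, and strips surrounding backticks and whitespace.
    """
    head = response.split("\nQuestion:", 1)[0].strip()
    if head.startswith("SQLQuery:"):
        head = head[len("SQLQuery:"):]
    return head.strip().strip("`").strip()


# Shortened stand-in for the completion shown above.
raw = (
    "`SELECT SUM(asset) FROM fact_retention_model WHERE raisetype='公募基金'`\n\n"
    "Question: another question the model made up\n"
    "SQLQuery: `SELECT 1`"
)
print(extract_first_sql(raw))
```

Adding `"\nQuestion:"` to the `stop` list passed to `llm.bind` might also cut generation short and reduce the long response time, though that has not been tried in this notebook.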

---
## SQLDatabaseChain
Second approach.
> The performance of the SQLDatabaseChain can be enhanced in several ways:
* Adding sample rows
* Specifying custom table information
* Using the query checker to self-correct invalid SQL, with the parameter `use_query_checker=True`
* Customizing the LLM prompt to include specific instructions or relevant information, with the parameter `prompt=CUSTOM_PROMPT`
* Getting intermediate steps, to access the SQL statement as well as the final result, with the parameter `return_intermediate_steps=True`
* Limiting the number of rows a query will return, with the parameter `top_k=5`
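The options listed above could be wired together roughly as follows. This is a sketch, not a tested cell: it assumes `langchain_experimental` is installed, and `CHAIN_KWARGS` and `build_sql_database_chain` are hypothetical names. Sample rows and custom table information are set on the `SQLDatabase` wrapper (`sample_rows_in_table_info`, `custom_table_info`) rather than on the chain. The `CUSTOM_PROMPT` defined earlier expects a `few_shot_examples` variable that this chain does not supply, so a prompt passed here would need to use only `input`, `table_info`, `dialect` and `top_k`.

```python
# Keyword arguments matching the options listed above.
CHAIN_KWARGS = {
    "use_query_checker": True,          # let the LLM review its own SQL first
    "return_intermediate_steps": True,  # expose the generated SQL with the answer
    "top_k": 5,                         # cap on rows requested per query
    "verbose": True,
}


def build_sql_database_chain(llm, db, prompt=None):
    """Build a SQLDatabaseChain; requires `langchain_experimental`."""
    from langchain_experimental.sql import SQLDatabaseChain

    # prompt=None falls back to the library's default prompt for the dialect.
    return SQLDatabaseChain.from_llm(llm, db, prompt=prompt, **CHAIN_KWARGS)


# Usage with the `llm` and `db` objects from the first cell:
# db_chain = build_sql_database_chain(llm, db)
# result = db_chain({"query": "..."})
# result["result"] holds the answer, result["intermediate_steps"] the SQL and rows.
```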

---
## SQLDatabaseSequentialChain
Third approach, which adds the following benefits on top of the second approach:
> Determining which tables to use based on the user question <br>
> Calling the normal SQL database chain using only relevant tables

In [None]:
from langchain_experimental.sql import SQLDatabaseSequentialChain
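A minimal sketch of how the imported class might be used, assuming `langchain_experimental` is installed and the library's default decider and query prompts are acceptable. `build_sequential_chain` is a hypothetical helper name.

```python
def build_sequential_chain(llm, db):
    """Build a SQLDatabaseSequentialChain; requires `langchain_experimental`."""
    from langchain_experimental.sql import SQLDatabaseSequentialChain

    # Step 1 (decider): the LLM picks the tables relevant to the question.
    # Step 2: the regular SQL database chain runs against only those tables.
    return SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True)


# Usage with the `llm` and `db` objects from the first cell:
# seq_chain = build_sequential_chain(llm, db)
# answer = seq_chain.run("...")
```

Because `db` in this notebook is restricted to `fact_retention_model` through `include_tables`, the decider step has only one candidate table; the approach pays off once several tables are exposed.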
