# SQL-based chains
Implementation inspired by:
> https://python.langchain.com/docs/use_cases/qa_structured/sql

which adopts best practices from:
> https://blog.langchain.dev/llms-and-sql/

In [None]:
#### Load retriever vectorstore for similar questions and queries

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from chromadb.config import Settings
import chromadb


# Sentence-embedding model that can handle Chinese text.
# NOTE: if loading fails with `No transformers found`, apply the fix described at
# https://huggingface.co/uer/sbert-base-chinese-nli/commit/827d1b828afae1e0fde9f26a590ff0e1eaded589#d2h-223453
embedding_model = "uer/sbert-base-chinese-nli"
persist_directory = "../vectorstore/"

# Retrieval knobs: return at most `top_k` documents, and only those whose
# relevance score meets `score_threshold`.
top_k = 3
score_threshold = 0.8

CHROMA_SETTINGS = Settings(anonymized_telemetry=False, persist_directory=persist_directory)

embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
chroma_client = chromadb.PersistentClient(path=persist_directory, settings=CHROMA_SETTINGS)
vectordb = Chroma(
    client=chroma_client,
    client_settings=CHROMA_SETTINGS,
    embedding_function=embeddings,
    persist_directory=persist_directory,
)

# The score-threshold search type lets an irrelevant question come back with
# no documents at all, instead of its nearest (but unrelated) neighbours.
retriever = vectordb.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs=dict(k=top_k, score_threshold=score_threshold),
)

In [None]:
#### Load db wrapper and llm

from langchain.utilities import SQLDatabase
from langchain.llms import GPT4All
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler


# Local SQLite database and quantized GPT4All model file; both paths are
# relative to the notebook's working directory.
db_uri = "sqlite:///../db/retention-sqlite.db"
model_path = "../models/ggml-model-gpt4all-falcon-q4_0.bin"

# Expose only `fact_retention_model` to the LLM, and include two sample rows in
# the table description returned by `db.get_table_info()`.
db = SQLDatabase.from_uri(
    db_uri,
    sample_rows_in_table_info=2,
    include_tables=["fact_retention_model"],
)

# Generated tokens are streamed to stdout by the callback handler.
# NOTE(review): `backend="gptj"` is passed for a Falcon model file; confirm the
# installed gpt4all bindings honour or ignore this setting.
llm = GPT4All(
    model=model_path,
    backend="gptj",
    max_tokens=1000,
    n_batch=8,
    verbose=True,
    callbacks=[StreamingStdOutCallbackHandler()],
)

In [None]:
#### embeddings prompt template 

from typing import Any, List
from langchain.prompts import PromptTemplate
from langchain.schema.prompt_template import BasePromptTemplate
from langchain.schema import format_document
# from langchain.chains import StuffDocumentsChain
import json


# Curated mapping of example question -> reference SQL query. Retrieved
# documents carry the question text; it is looked up here to attach its SQL
# as a few-shot example.
_specific_question_and_query_file = "../downloads/retention/question_query.json"
_specific_question_and_query_dict = {}

# Renders a retrieved document as its raw text (the stored question).
_specific_document_prompt_template = PromptTemplate(
    input_variables=["page_content"],
    template="{page_content}",
)
# Wraps the joined question/query pairs into the few-shot section of the prompt.
_specific_prompt_template = PromptTemplate(
    input_variables=["context"],
    template="""Some examples of SQL queries that correspond to questions are listed below:

{context}
""",
)

# Questions in this notebook may be non-ASCII (it targets Chinese), so decode
# explicitly as UTF-8 -- the encoding JSON interchange requires -- instead of
# relying on the platform's default encoding (e.g. cp936/cp1252 on Windows).
with open(_specific_question_and_query_file, "r", encoding="utf-8") as file:
    _specific_question_and_query_dict = json.load(file)


def _get_specific_question_and_query(question: str) -> str:
    if question is not None:
        try:
            return f"Question: {question}\nQuery: {_specific_question_and_query_dict[question]}"
        except KeyError as ke:
            print(f"No such key as {question}")
            return question
    else:
        print("Key is none")
        return question


def specific_stuff_combine_docs(
    question: str,
    document_prompt: BasePromptTemplate = _specific_document_prompt_template,
    document_separator: str = "\n\n",
    prompt: BasePromptTemplate = _specific_prompt_template,
    document_variable_name: str = "context",
    input_variables: List[str] | None = None,
    verbose: bool = False,
    **kwargs: Any,
    ) -> str:
    """Build the few-shot examples section for *question* without calling an LLM.

    Retrieves documents for *question* from the module-level ``retriever``,
    renders each one with *document_prompt*, turns it into a
    ``Question: ...\\nQuery: ...`` snippet via
    ``_get_specific_question_and_query``, joins the snippets with
    *document_separator*, and formats the result through *prompt*.

    Args:
        question: User question used as the retrieval query.
        document_prompt: Template that renders a retrieved document to text.
        document_separator: String placed between rendered documents.
        prompt: Template that receives the joined documents.
        document_variable_name: Name of *prompt*'s variable that receives the
            joined documents.
        input_variables: Keys of ``kwargs`` that are forwarded to *prompt*.
            Defaults to ``["context"]``.
        verbose: When true, print the formatted prompt.
        **kwargs: Additional prompt values; only keys listed in
            *input_variables* are used.

    Returns:
        The formatted prompt, or ``""`` when the retriever returns no documents.
    """
    # Avoid a shared mutable default list.
    if input_variables is None:
        input_variables = ["context"]

    docs = retriever.get_relevant_documents(question)
    if not docs:
        # Nothing relevant enough: leave the few-shot section empty.
        return ""
    doc_strings = [
        _get_specific_question_and_query(format_document(doc, document_prompt))
        for doc in docs
    ]
    # NOTE: unlike StuffDocumentsChain (get_default_document_variable_name),
    # the prompt's variables are not validated against document_variable_name.
    inputs = {k: v for k, v in kwargs.items() if k in input_variables}
    inputs[document_variable_name] = document_separator.join(doc_strings)
    formatted = prompt.format(**inputs)
    if verbose:
        print(formatted)
    return formatted

---
## RunnableSequence
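First approach, based on *create_sql_query_chain*.
> Be careful when using this approach, as it is susceptible to SQL injection.

To mitigate this, keep SQL generation (`invoke`) separate from SQL execution (`run`), so that a generated query can be reviewed before it is executed.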

In [None]:
#### RunnableSequence, no memory

import time
from langchain.chains import create_sql_query_chain
from langchain.schema.output_parser import StrOutputParser


## NOTE:
# This chain automates finding table infos
# which we can improve by annotating from an agent

# Text-to-SQL prompt. Placeholders: {dialect}, {top_k}, {table_info},
# {few_shot_examples} and {input}. The prompt ends at "Question: {input}", so
# the model's completion starts right after the question text.
# NOTE(review): the LIMIT sentence says "as per SQLite" although the dialect is
# a placeholder; it only matches because `db` points at a SQLite database.
TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}.

{few_shot_examples}.

Question: {input}"""

# Declares exactly the placeholders used in TEMPLATE.
CUSTOM_PROMPT = PromptTemplate(
    input_variables=["dialect", "top_k", "table_info", "few_shot_examples", "input"], template=TEMPLATE
)

# Stock LangChain SQL-generation chain using the custom prompt. The interactive
# loop in this cell assembles an equivalent pipeline by hand instead.
_chain = create_sql_query_chain(llm=llm, db=db, prompt=CUSTOM_PROMPT)

# Create chain with LangChain Expression Language.
# The pipeline does not depend on the question, so it is built once here
# instead of being rebuilt on every loop iteration.
# NOTE: rlm example
# Refer to https://smith.langchain.com/hub/rlm/text-to-sql
inputs = {
    "table_info": lambda _: db.get_table_info(),
    # Prime the completion so the model continues straight into the SQL text.
    "input": lambda x: x["question"] + "\nSQLQuery: ",
    "few_shot_examples": lambda x: specific_stuff_combine_docs(x["question"]),
    "dialect": lambda _: db.dialect,
    "top_k": lambda _: 5,
}
# equivalent to _chain, but more flexible
runnableMap = (
    inputs
    | CUSTOM_PROMPT
    # Stop before the model invents a result: this pipeline only generates
    # SQL, it never executes it.
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

# Interactive loop: enter a question, or "exit" to quit; blank input is ignored.
while True:
    question = input("\nEnter a question: ")
    if question.strip() == "exit":
        break
    if question.strip() == "":
        continue

    # perf_counter is monotonic and meant for measuring elapsed time.
    start = time.perf_counter()
    sql_response = runnableMap.invoke({"question": question})
    end = time.perf_counter()

    # Print the result
    print("\n\n> Question:")
    print(question)
    print(f"\n> Answer (took {round(end - start, 2)} s.):")
    print(sql_response)



---
## SQLDatabaseChain with Memory
Second approach: essentially the RunnableMap above, plus SQL execution and memory.
> The performance of the SQLDatabaseChain can be enhanced in several ways:
* Adding sample rows
* Specifying custom table information
* Using the query checker to self-correct invalid SQL: parameter `use_query_checker=True`
* Customizing the LLM prompt to include specific instructions or relevant information: parameter `prompt=CUSTOM_PROMPT`
* Getting intermediate steps to access the SQL statement as well as the final result: parameter `return_intermediate_steps=True`
* Limiting the number of rows a query will return: parameter `top_k=5`

In [None]:
#### Custom prompt for LLMChain in SQLDatabaseChain

# SQLite-specific instructions for the query-generating LLM.
# Placeholder: {top_k}. The text ends with a blank line so the suffix below can
# be appended directly.
_sqlite_prompt = """You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

"""

# Second half of the prompt: schema, retrieved few-shot examples, conversation
# memory, then the question.
# Placeholders: {table_info}, {few_shot_examples}, {history}, {input}.
_prompt_suffix = """Only use the following tables:
{table_info}
{few_shot_examples}

Relevant pieces of previous conversation:
{history}
(You do not need to use these pieces of information if not relevant)

Question: {input}"""

# Full prompt for the LLM chain used inside SQLDatabaseChain.
# NOTE(review): `input`, `top_k` and `table_info` are presumably supplied by
# SQLDatabaseChain, while `few_shot_examples` and `history` must be filled in
# by the LLM chain itself (retrieval and memory) -- verify against the wrapper.
CUSTOM_PROMPT_WITH_FEW_SHOT_EXAMPLES_AND_MEMORY = PromptTemplate(
    input_variables=["input", "table_info", "top_k", "few_shot_examples", "history"],
    template=_sqlite_prompt + _prompt_suffix,
)

In [None]:
#### FIXME: Not working

from typing import Any, Dict, Optional, Union
from langchain_experimental.sql import SQLDatabaseChain
from langchain.chains.llm import LLMChain
from langchain.chains.base import Chain
from langchain.schema import BaseMemory
from langchain.memory import ConversationBufferMemory
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
from langchain.schema.language_model import BaseLanguageModel


class LLMChainWithMemory(LLMChain):
    """LLMChain that injects conversation memory and retrieved few-shot examples.

    Intended as the ``llm_chain`` of a ``SQLDatabaseChain``. Before each call
    it merges the memory variables into the inputs and fills
    ``few_shot_examples`` with SQL examples retrieved for the user's question.
    """

    def prep_inputs(self, inputs: Dict[str, Any] | Any) -> Dict[str, str]:
        """Validate and enrich *inputs* before the prompt is formatted.

        Args:
            inputs: Either a dict of prompt inputs, or a single value when the
                chain has exactly one non-memory input key.

        Returns:
            The inputs merged with the memory variables (or an empty
            ``history`` when no memory is attached) plus ``few_shot_examples``.

        Raises:
            ValueError: If a single non-dict value is passed but the chain
                expects several inputs.
        """
        if not isinstance(inputs, dict):
            _input_keys = set(self.input_keys)
            if self.memory is not None:
                _input_keys = _input_keys.difference(self.memory.memory_variables)
            if len(_input_keys) != 1:
                raise ValueError(
                    f"A single string input was passed in, but this chain expects "
                    f"multiple inputs ({_input_keys}). When a chain expects "
                    f"multiple inputs, please call it by passing in a dictionary, "
                    "eg `chain({'foo': 1, 'bar': 2})`"
                )
            inputs = {next(iter(_input_keys)): inputs}

        # Without memory, still provide the {history} variable the prompt needs,
        # as an empty string rather than a list rendered as "[]".
        external_context = (
            self.memory.load_memory_variables(inputs)
            if self.memory is not None
            else {"history": ""}
        )
        inputs = dict(inputs, **external_context)

        # SQLDatabaseChain sends "<question>\nSQLQuery:" for SQL generation and
        # "<question>\nSQLQuery:<sql>\nSQLResult: <rows>\nAnswer:" for the
        # answer step. Retrieve examples with the bare question in both cases.
        # (str.strip would remove a *set of characters* from both ends, not a
        # suffix, and could eat letters of the question itself.)
        question, _, _ = inputs["input"].partition("\nSQLQuery:")
        inputs["few_shot_examples"] = specific_stuff_combine_docs(question)
        self._validate_inputs(inputs)
        return inputs




# Options forwarded to SQLDatabaseChain. Options left at their defaults:
#   return_direct        -- whether to return the result of querying the SQL table directly
#   query_checker_prompt -- prompt template used by the query checker
SQLDatabaseChain_kwargs = dict(
    return_sql=False,  # True: return the SQL command directly without executing it
    return_intermediate_steps=False,  # True: also return the intermediate steps with the final answer
    use_query_checker=False,  # True: let the query checker try to fix the LLM's initial SQL
)

# Conversation memory for the LLM chain: records the "input" value of each call
# and exposes the transcript to the prompt as {history}.
memory = ConversationBufferMemory(memory_key="history", input_key="input")
chain = SQLDatabaseChain(
    database=db,
    # The LLM chain performs the actual prediction, using the memory-aware prompt.
    llm_chain=LLMChainWithMemory(
        llm=llm,
        prompt=CUSTOM_PROMPT_WITH_FEW_SHOT_EXAMPLES_AND_MEMORY,
        memory=memory,
    ),
    **SQLDatabaseChain_kwargs,
)

In [None]:
# First question (Chinese): "Query the holding scale and number of holders of
# advisory-type (投顾类) funds as of the end of last month."
chain({"query": "查询上月末产品类型为投顾类基金保有规模和人数" })

In [None]:
from pprint import pprint
# Show what the LLM chain's memory currently holds (the {history} fed to the prompt).
pprint(chain.llm_chain.memory.load_memory_variables({}))

In [None]:
# Follow-up question that depends on conversation memory (Chinese):
# "Among these, what is China Merchants Bank's holding scale?"
chain({"query": "这其中招商银行的保有规模是多少" })

---
## SQLDatabaseSequentialChain
Third approach. It adds the following benefits on top of the second approach, though it is still not as flexible as an agent:
> Determining which tables to use based on the user question <br>
> Calling the normal SQL database chain using only the relevant tables

In [None]:
from langchain_experimental.sql import SQLDatabaseSequentialChain