<a href="https://colab.research.google.com/github/christinezuzart/LangChain-for-LLM/blob/main/LLaMA2ChatWithSQLWithReplicateAPI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Packages**

In [1]:
! pip install langchain replicate

Collecting langchain
  Downloading langchain-0.0.346-py3-none-any.whl (2.0 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.1/2.0 MB[0m [31m2.6 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.8/2.0 MB[0m [31m12.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting replicate
  Downloading replicate-0.21.1-py3-none-any.whl (35 kB)
Collecting dataclasses-json<0.7,>=0.5.7 (from langchain)
  Downloading dataclasses_json-0.6.3-py3-none-any.whl (28 kB)
Collecting jsonpatch<2.0,>=1.33 (from langchain)
  Downloading jsonpatch-1.33-py2.py3-none-any.whl (12 kB)
Collecting langchain-core<0.1,>=0.0.10 (from langchain)
  Downloading langchain_core-0.0.10-py3-none-any.wh

In [2]:
# API
from langchain.llms import Replicate

In [3]:
import os
from google.colab import userdata

# Read the Replicate token from Colab's secret store (secret name: REPLICATE_API)
# and expose it under REPLICATE_API_TOKEN for the Replicate client.
os.environ.update(REPLICATE_API_TOKEN=userdata.get('REPLICATE_API'))

In [4]:
# Exact model version (name:version-hash) so reruns call the same model.
replicate_id = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d"

# Generation settings passed to the model as its `input`; a temperature this
# close to zero keeps the generated SQL nearly deterministic.
generation_params = {"temperature": 0.01, "max_length": 500, "top_p": 1}
llama2_chat_replicate = Replicate(model=replicate_id, input=generation_params)



In [5]:
# Single point of configuration: every chain below uses `llm`, so switching
# to a different model only requires changing this assignment.
llm = llama2_chat_replicate

**DB**

In [6]:
from langchain.utilities import SQLDatabase


In [8]:
from google.colab import drive

# Drive must be mounted before the SQLite database stored there can be opened.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [9]:
# Chinook sample database kept in Google Drive (requires the Drive mount above).
CHINOOK_DB_PATH = "/content/drive/MyDrive/TalkToYourDB/chinook.db"
db = SQLDatabase.from_uri(f"sqlite:///{CHINOOK_DB_PATH}")

In [10]:
def get_schema(_):
    """Return ``db.get_table_info()``, the schema text injected into the prompts.

    The single positional argument is the chain input that
    ``RunnablePassthrough.assign`` passes in; it is deliberately ignored.
    """
    table_info = db.get_table_info()
    return table_info

In [11]:
import re


def clean_sql(text):
    """Strip the decoration an LLM tends to wrap around a SQL statement.

    Removes surrounding whitespace, a Markdown code fence (```sql ... ```),
    a leading speaker tag such as ``AI:`` and a leading ``SQL Query:`` label.
    For example, ``' AI:  SQL Query:'`` followed by a newline and
    ``SELECT 1;`` becomes ``'SELECT 1;'``. Text without any of these
    decorations is returned with only its outer whitespace stripped.
    """
    sql = text.strip()
    # Prefer the contents of a fenced code block if the model produced one.
    fenced = re.search(r"```(?:sql)?\s*(.*?)\s*```", sql, flags=re.DOTALL | re.IGNORECASE)
    if fenced:
        sql = fenced.group(1)
    # Completion models can echo the chat speaker tag and the prompt's label.
    sql = re.sub(
        r"^(?:(?:AI|Assistant)\s*:\s*)?(?:SQL\s*Query\s*:\s*)?",
        "",
        sql.strip(),
        flags=re.IGNORECASE,
    )
    return sql.strip()


def run_query(query):
    """Execute LLM-generated SQL against ``db`` and return the result text.

    The query is passed through ``clean_sql`` first, because the raw completion
    can carry a prefix such as ``AI:  SQL Query:`` that SQLite rejects.
    Errors raised by ``db.run`` (e.g. an ``OperationalError`` for an unknown
    column) propagate to the caller.
    """
    return db.run(clean_sql(query))

**Query a SQL DB**

In [12]:
# Prompt
from langchain.prompts import ChatPromptTemplate

# Human turn: the live schema plus the user's question; the model completes after "SQL Query:".
template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""
system_instruction = "Given an input question, convert it to a SQL query. No pre-amble."
prompt = ChatPromptTemplate.from_messages([("system", system_instruction), ("human", template)])

In [14]:
# Chain to query
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

# Stop sequence: cut generation off if the model starts writing a "SQLResult:" section.
sql_llm = llm.bind(stop=["\nSQLResult:"])
sql_response = RunnablePassthrough.assign(schema=get_schema) | prompt | sql_llm | StrOutputParser()



In [15]:
# Smoke test of the SQL-generation chain: the result is the generated SQL text.
sql_response.invoke({"question": "How many employees are there?"})

' SELECT COUNT(*) FROM employees;'

In [16]:
# Chain to answer
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Given an input question and SQL response, convert it to a natural language answer. No pre-amble.",
        ),
        ("human", template),
    ]
)

# 1) generate SQL, 2) execute it via the `run_query` helper (instead of an
#    inline db.run lambda), 3) have the LLM phrase the result for the user.
full_chain = (
    RunnablePassthrough.assign(query=sql_response)
    | RunnablePassthrough.assign(
        schema=get_schema,
        response=lambda x: run_query(x["query"]),
    )
    | prompt_response
    | llm
)

In [17]:
# End-to-end run: generate SQL, execute it, and phrase the result in natural language.
full_chain.invoke({"question": "How many employees are there?"})

' Sure! Here\'s the natural language answer based on the given SQL query and response:\n\n"There are 8 employees."'

In [18]:
# NOTE(review): in the recorded run this raised OperationalError (Colab shows the
# message as "ignored"). The LLM's SQL is executed without validation, so a query
# naming a non-existent table/column, or carrying extra text, fails at this step.
full_chain.invoke({"question": "Who are the top 3 best selling artists?"})

OperationalError: ignored

**Chat with a SQL DB**
Next, we add conversation memory so that follow-up questions can build on earlier turns.

In [19]:
# Prompt
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

template = """Given an input question, convert it to a SQL query. No pre-amble. Based on the table schema below, write a SQL query that would answer the user's question:
{schema}
"""
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", template),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ]
)


In [20]:
# Conversation memory read and written by the chain cells below; return_messages=True
# keeps "history" as a list of chat messages, suitable for the MessagesPlaceholder in `prompt`.
memory = ConversationBufferMemory(return_messages=True)

In [21]:
# Chain to query with memory
from langchain.schema.runnable import RunnableLambda


def load_history(chain_input):
    """Return the conversation so far, as stored in ``memory``, for the prompt's history slot."""
    return memory.load_memory_variables(chain_input)["history"]


sql_chain = (
    RunnablePassthrough.assign(schema=get_schema, history=RunnableLambda(load_history))
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)


In [23]:
def save(input_output):
    """Record one question -> SQL exchange in ``memory`` and return the SQL.

    ``input_output`` is the chain dict holding the user's inputs plus an
    ``"output"`` key with the generated SQL. That key is popped (the dict is
    modified in place), so only the inputs are saved as the human side.
    """
    generated_sql = input_output.pop("output")
    memory.save_context(input_output, {"output": generated_sql})
    return generated_sql

In [24]:
# Generate SQL, then record the exchange in memory; the chain's result is the SQL string.
sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | RunnableLambda(save)

In [26]:
# First conversational turn (also saved to `memory`). The recorded output shows the
# raw completion starts with "AI:  SQL Query:" before the actual SQL.
sql_response_memory.invoke({"question": "Who are the top 3 best selling artists?"})

' AI:  SQL Query:\nSELECT ArtistId, Name FROM artists ORDER BY Sales DESC LIMIT 3;'

In [27]:
# Chain to answer (with memory)
# The answer prompt is identical to the one in the "Query a SQL DB" section, so the
# `prompt_response` defined there is reused here instead of being redefined.
full_chain = (
    RunnablePassthrough.assign(query=sql_response_memory)
    | RunnablePassthrough.assign(
        schema=get_schema,
        # Execute through the shared `run_query` helper rather than an inline db.run lambda.
        response=lambda x: run_query(x["query"]),
    )
    | prompt_response
    | llm
)


In [28]:
# NOTE(review): in the recorded run this raised OperationalError. The memory chain's SQL
# is executed exactly as generated, and the previous invocation's output began with
# "AI:  SQL Query:" (not valid SQL) and ordered by a `Sales` column that may not exist
# in `artists`. Likely the same causes here; confirm against the full traceback.
full_chain.invoke({"question": "Who are the top 3 best selling artists?"})

OperationalError: ignored