# Environment setup

In [1]:
import os

from dotenv import load_dotenv

# Load .env BEFORE reading any variables: the original read LANGCHAIN_* first,
# so anything defined only in .env came back as None.
load_dotenv()

LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
LANGCHAIN_ENDPOINT = os.getenv("LANGCHAIN_ENDPOINT")
LANGCHAIN_TRACING_V2 = os.getenv("LANGCHAIN_TRACING_V2")
# NOTE(review): not referenced by any other cell shown here (the chains use local
# Ollama models); os.environ[...] raises KeyError if GROQ_API_KEY is unset.
groq_api_key = os.environ["GROQ_API_KEY"]

# Imports

In [2]:
from langchain_community.chat_models import ChatOllama
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.memory import ConversationBufferMemory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.pydantic_v1 import BaseModel

# Define the LLMs

In [4]:
# Two local Ollama chat models with identical settings (temperature=0.1, timeout=300).
phi_llm, gemma_llm = (
    ChatOllama(model=model_name, temperature=0.1, timeout=300)
    for model_name in ("phi", "gemma:2b")
)
# Default-model alias; note the chains below reference gemma_llm directly.
llm = phi_llm

# DB: Connect to a SQLite DB.

In [5]:
# Open the SQLite file read-only. The chains below execute SQL written by an LLM,
# so a read-only connection keeps generated statements from modifying data.
# `uri=true` makes SQLite parse the "file:...?mode=ro" form; with mode=ro a missing
# file raises an error instead of silently creating an empty database.
# sample_rows_in_table_info=0: the schema text given to the LLM has no sample rows.
db = SQLDatabase.from_uri(
    "sqlite:///file:nba_roster.db?mode=ro&uri=true",
    sample_rows_in_table_info=0,
)

In [6]:
# Sanity check: look up one player's team directly, without the LLM.
query = (
    "SELECT Team FROM nba_roster "
    "WHERE NAME = 'Klay Thompson'"
)
db.run(query)

"[('Golden State Warriors',)]"

In [7]:
# Sanity check: count the distinct teams directly in SQL.
query = (
    "SELECT COUNT(DISTINCT Team) "
    "FROM nba_roster"
)
db.run(query)

'[(30,)]'

In [9]:
def get_schema(_):
    """Return the table-info text for ``db``; the single argument is ignored.

    The unused positional parameter lets this be dropped straight into a
    ``RunnablePassthrough.assign`` step, which calls it with the chain input.
    """
    schema_text = db.get_table_info()
    return schema_text

def run_query(query):
    """Execute ``query`` against ``db`` and return whatever ``db.run`` returns."""
    execute = db.run
    return execute(query)

# Prompts and templates

In [39]:
# Human-turn template for SQL generation; {schema} and {question} are supplied
# when the prompt is invoked.
template = """Based on the table schema below, return only a SQL query that would answer the user's question. No pre-amble.:
{schema}

Question: {question}
Answer:"""  # noqa: E501


# Message order: fixed system instruction, then earlier turns from memory (the
# "history" placeholder), then the current question rendered with the schema.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Given an input question, convert it to a SQL query. No pre-amble."),
        MessagesPlaceholder(variable_name="history"),
        ("human", template),
    ]
)

In [40]:
# Display the SQL-generation prompt via its rich repr (same as the
# prompt_response cell) to confirm its input variables.
prompt

input_variables=['history', 'question', 'schema'] input_types={'history': typing.List[typing.Union[langchain_core.messages.ai.AIMessage, langchain_core.messages.human.HumanMessage, langchain_core.messages.chat.ChatMessage, langchain_core.messages.system.SystemMessage, langchain_core.messages.function.FunctionMessage, langchain_core.messages.tool.ToolMessage]]} messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='Given an input question, convert it to a SQL query. No pre-amble.')), MessagesPlaceholder(variable_name='history'), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question', 'schema'], template="Based on the table schema below, return only a SQL query that would answer the user's question. No pre-amble.:\n{schema}\n\nQuestion: {question}\nAnswer:"))]


In [42]:
# Conversation memory for the SQL-generation chain. return_messages=True makes
# load_memory_variables() return message objects for the "history" placeholder.
memory = ConversationBufferMemory(
    return_messages=True,
)

# Chains

In [61]:
# Chain to query with memory: {"question": ...} -> SQL string.


def _extract_sql(completion: str) -> str:
    """Return the SQL statement contained in a raw LLM completion.

    Replaces the former ``stop=["\\n"]``, which cut every multi-line or
    code-fenced answer down to its first line (e.g. just "```sql").
    Keeps the body of the first Markdown code fence if there is one, then drops
    everything after the first blank line and after the first ';'.
    Caveat: a ';' or blank line inside a SQL string literal would also end it.
    """
    import re

    text = completion.strip()
    fence = text.find("```")
    if fence != -1:
        body = text[fence + 3:]
        # Drop an optional "sql"/"sqlite" language tag after the opening fence.
        body = re.sub(r"^(?:sqlite|sql)?[ \t]*\r?\n?", "", body, flags=re.IGNORECASE)
        # The closing fence is optional (generation may stop before it).
        text = body.split("```", 1)[0]
    # Trailing explanation is usually separated from the query by a blank line.
    text = re.split(r"\n[ \t]*\n", text.strip(), maxsplit=1)[0]
    statement, semicolon, _ = text.partition(";")
    return (statement + semicolon).strip()


sql_chain = (
    RunnablePassthrough.assign(
        schema=get_schema,
        history=RunnableLambda(lambda x: memory.load_memory_variables(x)["history"]),
    )
    | prompt
    | gemma_llm
    | StrOutputParser()
    | RunnableLambda(_extract_sql)
)

In [62]:
def save(input_output):
    """Record one question/SQL exchange in ``memory`` and return the SQL.

    Args:
        input_output: dict of the chain inputs plus the generated SQL under
            ``"output"`` (as built by ``RunnablePassthrough.assign(output=...)``).

    Returns:
        The value stored under ``"output"``.

    Raises:
        KeyError: if ``input_output`` has no ``"output"`` key.
    """
    # Build a filtered copy instead of pop()-ing, so the caller's dict is not mutated.
    output = {"output": input_output["output"]}
    inputs = {key: value for key, value in input_output.items() if key != "output"}
    memory.save_context(inputs, output)
    return output["output"]

In [63]:
# Generate SQL into the "output" key, then hand the dict to `save`, which records
# the exchange in memory and returns the SQL string.
sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | RunnableLambda(save)

In [64]:
# Chain to answer: prompt that turns the schema, question, generated SQL and its
# result into a natural-language reply.
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""  # noqa: E501

prompt_response = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Given an input question and SQL response, convert it to a natural language answer. No pre-amble.",  # noqa: E501
        ),
        ("human", template),
    ]
)

In [65]:
# Display the answer prompt's repr to confirm its input variables.
prompt_response

ChatPromptTemplate(input_variables=['query', 'question', 'response', 'schema'], messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='Given an input question and SQL response, convert it to a natural language answer. No pre-amble.')), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['query', 'question', 'response', 'schema'], template='Based on the table schema below, question, sql query, and sql response, write a natural language response:\n{schema}\n\nQuestion: {question}\nSQL Query: {query}\nSQL Response: {response}'))])

In [66]:
# Input schema for `chain` (not the prompt): a single natural-language `question`.
# Passed to `with_types(input_type=...)` so the chain's expected input is explicit.
class InputType(BaseModel):
    question: str

In [67]:
# End-to-end chain: question -> SQL (sql_response_memory also records the exchange
# in memory) -> SQL result -> natural-language answer (returned as an AIMessage).
# NOTE(review): a recorded run of this chain raised "AssertionError: The input to
# RunnablePassthrough.assign() must be a dict"; root cause not determined here.
chain = (
    RunnablePassthrough.assign(query=sql_response_memory).with_types(
        input_type=InputType
    )
    | RunnablePassthrough.assign(
        schema=get_schema,
        # Reuse the run_query helper rather than calling db.run inline.
        response=lambda x: run_query(x["query"]),
    )
    | prompt_response
    | gemma_llm
)

In [68]:
# Single-player lookup through the full chain.
chain.invoke(dict(question="What team is Klay Thompson on?"))

AIMessage(content='Based on the provided table schema and responses, Klay Thompson is currently playing for the Golden State Warriors.', response_metadata={'model': 'gemma:2b', 'created_at': '2024-04-24T10:32:24.071621Z', 'message': {'role': 'assistant', 'content': ''}, 'done': True, 'total_duration': 3181013400, 'load_duration': 15497800, 'prompt_eval_count': 164, 'prompt_eval_duration': 2631502000, 'eval_count': 22, 'eval_duration': 526091000}, id='run-90d3c59a-e831-4994-bdb0-bd583bd59d62-0')

In [69]:
# Aggregate (COUNT) question for one team.
chain.invoke(dict(question="Give me total number of players in Golden State Warriors Team?"))

AIMessage(content='Based on the provided table schema and SQL response, there are 17 players currently on the Golden State Warriors team.', response_metadata={'model': 'gemma:2b', 'created_at': '2024-04-24T10:33:24.9903513Z', 'message': {'role': 'assistant', 'content': ''}, 'done': True, 'total_duration': 3302144100, 'load_duration': 3242700, 'prompt_eval_count': 167, 'prompt_eval_duration': 2684497000, 'eval_count': 25, 'eval_duration': 610001000}, id='run-dd374060-2234-4ef8-af32-34a1d4dfc2ee-0')

In [70]:
# Distinct-count question across the whole table.
chain.invoke(dict(question="How many total different teams are there in database?"))

AIMessage(content='Based on the provided table schema and SQL response, there are 30 distinct teams in the NBA roster.', response_metadata={'model': 'gemma:2b', 'created_at': '2024-04-24T10:36:06.3236272Z', 'message': {'role': 'assistant', 'content': ''}, 'done': True, 'total_duration': 3243913700, 'load_duration': 3923700, 'prompt_eval_count': 162, 'prompt_eval_duration': 2683971000, 'eval_count': 23, 'eval_duration': 546693000}, id='run-474b3312-0940-43ee-9a89-7670b7e59c62-0')

In [71]:
# Question text fixed ("roaster" -> "roster"); the misspelling was fed verbatim
# to the SQL-generation model.
chain.invoke({"question": "Who is the oldest player in the NBA roster?"})

AssertionError: The input to RunnablePassthrough.assign() must be a dict.