## Chain
1. Take natural language query from
2. Translate (using code LLM) the natural language query from #1 into SQL command
3. Run the query from #2 against the db

#### Import project dependencies

In [None]:
from dotenv import load_dotenv
# from langchain_community.utilities import SQLDatabase
from langchain.sql_database import SQLDatabase
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from functools import partial

import langchain
langchain.verbose=True

import sqlglot
import os

#### Pull database into project

In [2]:
db=SQLDatabase.from_uri("sqlite:///nba_stats.db")
print(db.get_table_info())


CREATE TABLE player_stats (
	"index" INTEGER, 
	year TEXT, 
	"Season_type" TEXT, 
	"PLAYER_ID" INTEGER, 
	"RANK" INTEGER, 
	"PLAYER" TEXT, 
	"TEAM_ID" INTEGER, 
	"TEAM" TEXT, 
	"GP" INTEGER, 
	"MIN" INTEGER, 
	"FGM" INTEGER, 
	"FGA" INTEGER, 
	"FG_PCT" REAL, 
	"FG3M" INTEGER, 
	"FG3A" INTEGER, 
	"FG3_PCT" REAL, 
	"FTM" INTEGER, 
	"FTA" INTEGER, 
	"FT_PCT" REAL, 
	"OREB" INTEGER, 
	"DREB" INTEGER, 
	"REB" INTEGER, 
	"AST" INTEGER, 
	"STL" INTEGER, 
	"BLK" INTEGER, 
	"TOV" INTEGER, 
	"PF" INTEGER, 
	"PTS" INTEGER, 
	"AST_TOV" REAL, 
	"STL_TOV" REAL
)

/*
3 rows from player_stats table:
index	year	Season_type	PLAYER_ID	RANK	PLAYER	TEAM_ID	TEAM	GP	MIN	FGM	FGA	FG_PCT	FG3M	FG3A	FG3_PCT	FTM	FTA	FT_PCT	OREB	DREB	REB	AST	STL	BLK	TOV	PF	PTS	AST_TOV	STL_TOV
0	2012-13	Regular_Season	201142	1	Kevin Durant	1610612760	OKC	81	3119	731	1433	0.51	139	334	0.416	679	750	0.905	46	594	640	374	116	105	280	143	2280	1.34	0.41
1	2012-13	Regular_Season	977	2	Kobe Bryant	1610612747	LAL	78	3013	738	1595	0.463	132

### Create functions that will be used as chain inputs

In [None]:
def get_schema(_):
    '''
    Input: This function does not require a input '_'
    Output: Database info about db tables used as schema
    '''
    return db.get_table_info()

In [4]:
def run_sql(query: str, dialect: str) -> str:
    """
    Transpile and execute an SQL query using the specified dialect.
    Parameters:
    - query (str): The SQL query to be transpiled and executed.
    - dialect (str): The target SQL dialect for transpilation and execution.
    Returns:
    str: The result of executing the transpiled SQL query.
    Example:
    ```
    result = run_sql("SELECT * FROM my_table", "sqlite")
    print(result)
    ```
    """
    query=sqlglot.transpile(query, read="postgres", write=dialect)[0]
    print(query)
    return db.run(query)

In [None]:
# from langchain_nvidia_ai_endpoints import ChatNVIDIA

In [None]:
load_dotenv('.env')
api_key = os.getenv("NV_API_KEY")

### Create a instance of ChatNVIDIA model name -- `mistralai/mamba-codestral-7b-v0.1`

In [9]:
sql_llm=ChatNVIDIA(
    model="mistralai/mamba-codestral-7b-v0.1",
    max_tokens=500,
    api_key = api_key
    
)

### Prepare SQL Prompt Template {input: question, schema}

In [10]:
SQL_GEN_TEMPLATE="""
### Task
Generate a SQLite query to answer, but use the LIMIT clause to limit the output to 5 results. Always use ILIKE on TEXT columns [QUESTION]{question}[/QUESTION]
### Instructions
- Always provide all the columns, If you cannot answer the question with the available database schema, return 'I do not know'
### Database Schema
The query will run on a database with the following schema:
{schema}
### Answer
Given the database schema, here is the SQLite query that answers [QUESTION]{question}[/QUESTION]
[SUFFIX][/SQL][PREFIX]▁[SQL]\n
"""

### Create and invoke chain - save results in variable `response`

In [None]:
prompt=PromptTemplate.from_template(SQL_GEN_TEMPLATE)
sql_chain=(
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | sql_llm.bind(stop=["[/SQL]", ";"])
    | StrOutputParser()
)
response=sql_chain.invoke({"question": "How does Kobe Bryant's shooting percentage compare to everyone's average shooting percentage?"})
print(response)

SELECT AVG(FG_PCT) as AVG_FG_PCT, FG_PCT as Kobe_FG_PCT, PLAYER
FROM player_stats
WHERE PLAYER LIKE '%Kobe Bryant%'
LIMIT 5


In [13]:
run_sql(response, 'sqlite')

SELECT AVG(FG_PCT) AS AVG_FG_PCT, FG_PCT AS Kobe_FG_PCT, PLAYER FROM player_stats WHERE PLAYER LIKE '%Kobe Bryant%' LIMIT 5


"[(0.40475000000000005, 0.463, 'Kobe Bryant')]"

### Create a instance of ChatNVIDIA model name -- `meta/llama3-70b-instruct`

In [14]:
chat_llm=ChatNVIDIA(
    model = "meta/llama3-70b-instruct",
    max_tokens=1000
)



### Create a `SystemMessageTemplate` and a `HumanMessageTemplate` to use for the `ChatPromptTemplate`

In [15]:
SQL_SUM_SYS_MSG=(
    "Given an input question and SQL response, convert it to a natural language answer."
    " Give a human like response"
)
SQL_SUM_TEMPLATE="""
Based on the table schema below, question, sql query, and sql response,\
 write a natural language response. Be generic and avoid errors as much as possible.
{schema}
Question: {question}
SQL Query: {sqlquery}
SQL Response: {response}
""".strip()

### Create a new chain `full_chain`, passing the previous outcome from `sql_chain` in as the first input
This is a crucial step where the model uses a RunnablePassthrough.assign tool to pass through the response from the sql_chain as the first input to the full_chain. 
This is where the natural language input is converted to SQL code to retrieve relevant information and return the 

In [None]:
prompt_response=ChatPromptTemplate.from_messages([
    ("system", SQL_SUM_SYS_MSG),
    ("human", SQL_SUM_TEMPLATE),
])
history=[]
def save_return(d, history=[]):
    history+=[d]
    return d
full_chain=(
    RunnablePassthrough.assign(sqlquery=sql_chain)
    | RunnablePassthrough.assign(
        schema=get_schema,
        response=lambda x: run_sql(x["sqlquery"], "sqlite") or "ERROR RETRIEVING FROM DATABASE",
    )
    | prompt_response
    | partial(save_return, history=history)
    | chat_llm
    | StrOutputParser()
)

In [17]:
print(sql_response.invoke({'question': 'how do i improve my shooting percentage?'}))

SELECT "PLAYER", "FG_PCT" AS "Shooting_Pct"
FROM player_stats
WHERE "FG_PCT" < 0.5
LIMIT 5


In [19]:
print(full_chain.invoke({'question': 'how do i improve my shooting percentage?'}))

# 

SELECT "PLAYER", "FG_PCT" AS "Shooting_Pct" FROM player_stats WHERE "FG_PCT" < 0.5 LIMIT 5


Exception: [401] Unauthorized
Invalid JWT serialization: Missing dot delimiter(s)
Please check or regenerate your API key.

In [None]:
print(full_chain.invoke({'question': 'who is the best 3 point shooter?'}))

# 

In [None]:
import gradio as gr
from langchain.sql_database import SQLDatabase
from langchain.chat_models import ChatOpenAI  # Replace this with your Llama integration
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Function to process natural language query
def process_query(nl_query):
    try:
        # Step 1: Translate natural language query into SQL
        prompt_template = PromptTemplate(
            input_variables=["query"],
            template="Translate the following natural language query into SQL:\n\n{query}",
        )
        llm = ChatOpenAI(temperature=0)  # Replace with your Llama LLM initialization
        sql_query = llm.generate(prompt_template.format(query=nl_query))
        
        # Step 2: Execute SQL query on the database
        db = SQLDatabase.from_uri("sqlite:///example.db")  # Update with your database connection string
        results = db.run(sql_query)
        return sql_query, results
    except Exception as e:
        return f"Error generating SQL: {str(e)}", "No results"

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Natural Language to SQL Query with Llama LLM")
    nl_query_input = gr.Textbox(
        label="Enter your natural language query",
        placeholder="e.g., Show me all users who joined last month.",
        lines=2,
    )
    sql_output = gr.Textbox(label="Generated SQL Query", lines=2, interactive=False)
    query_results = gr.Textbox(label="Query Results", lines=5, interactive=False)
    submit_button = gr.Button("Run Query")
    submit_button.click(process_query, inputs=nl_query_input, outputs=[sql &#8203;:contentReference[oaicite:0]{index=0}&#8203;


SyntaxError: incomplete input (238439787.py, line 36)