<a href="https://colab.research.google.com/github/hanhanwu/Hanhan_LangGraph_Exercise/blob/main/AI_for_BI/sql_agent_with_validation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SQL Agent with SQL Validation

In [1]:
%%capture --no-stderr
# Install/upgrade the LangGraph + LangChain packages used below; %%capture hides the installer's stdout while --no-stderr leaves stderr visible.
%pip install -U langgraph langchain_openai langchain_community

In [3]:
from google.colab import userdata

# Read the OpenAI API key stored under 'OPENAI_API_KEY' in Colab's userdata;
# it is passed to every ChatOpenAI(...) instance created later in this notebook.
OPENAI_API_KEY = userdata.get('OPENAI_API_KEY')

# About the Data

In [4]:
import requests

# Location of the Chinook sample SQLite database.
url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"

# Bound the request: without a timeout, a stalled connection would hang the cell forever.
response = requests.get(url, timeout=60)

if response.status_code == 200:
    # Open a local file in binary write mode
    with open("Chinook.db", "wb") as file:
        # Write the content of the response (the file) to the local file
        file.write(response.content)
    print("File downloaded and saved as Chinook.db")
else:
    print(f"Failed to download the file. Status code: {response.status_code}")

File downloaded and saved as Chinook.db


In [5]:
from langchain_community.utilities import SQLDatabase
import pandas as pd
import pprint
import json
import ast


# `db` and `all_table_names` are notebook-level names reused by the schema cell and the agent nodes below.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")  # load the downloaded DB
print(db.dialect)
all_table_names = db.get_usable_table_names()
print('All the Usable Tables:', all_table_names)
print()

# Preview a few rows of one table as a DataFrame.
table = 'Track'
print(f'Data Sample from Table {table}:')
# db.run hands the rows back as a string, so it is parsed into Python objects with ast.literal_eval.
sample = ast.literal_eval(db.run(f"SELECT * FROM {table} LIMIT 5;"))
# The rows carry no header, so the column names are fetched separately via pragma_table_info.
columns_query = f"""SELECT name FROM pragma_table_info('{table}')"""
cols = db.run(columns_query)
sample_df = pd.DataFrame(sample, columns=[col[0] for col in ast.literal_eval(cols)])
display(sample_df)

sqlite
All the Usable Tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']

Data Sample from Table Track:


Unnamed: 0,TrackId,Name,AlbumId,MediaTypeId,GenreId,Composer,Milliseconds,Bytes,UnitPrice
0,1,For Those About To Rock (We Salute You),1,1,1,"Angus Young, Malcolm Young, Brian Johnson",343719,11170334,0.99
1,2,Balls to the Wall,2,2,1,,342562,5510424,0.99
2,3,Fast As a Shark,3,2,1,"F. Baltes, S. Kaufman, U. Dirkscneider & W. Ho...",230619,3990994,0.99
3,4,Restless and Wild,3,2,1,"F. Baltes, R.A. Smith-Diesel, S. Kaufman, U. D...",252051,4331779,0.99
4,5,Princess of the Dawn,3,2,1,Deaffy & R.A. Smith-Diesel,375418,6290521,0.99


In [6]:
# Field names paired, in order, with each value of a `PRAGMA table_info` row.
schema_cols = ['cid', 'name', 'type', 'notnull', 'default_value', 'pk']

# Table name -> list of per-column dicts keyed by `schema_cols`.
all_table_schemas = {}
for table_name in all_table_names:
    # db.run returns the PRAGMA rows as a string; parse it back into tuples.
    pragma_rows = ast.literal_eval(db.run(f"PRAGMA table_info({table_name});"))
    all_table_schemas[table_name] = [
        dict(zip(schema_cols, row)) for row in pragma_rows
    ]

# Show only the first table's schema as a sanity check.
for shown_table, shown_columns in all_table_schemas.items():
  print(shown_table)
  print(shown_columns)
  break

Album
[{'cid': 0, 'name': 'AlbumId', 'type': 'INTEGER', 'notnull': 1, 'default_value': None, 'pk': 1}, {'cid': 1, 'name': 'Title', 'type': 'NVARCHAR(160)', 'notnull': 1, 'default_value': None, 'pk': 0}, {'cid': 2, 'name': 'ArtistId', 'type': 'INTEGER', 'notnull': 1, 'default_value': None, 'pk': 0}]


# Build Agent

### The State of the Agent

In [7]:
from typing import List
from typing import Annotated, Sequence
from typing_extensions import TypedDict

from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages


class AgentState(TypedDict):
    """State dictionary passed between the agent's nodes."""
    # Conversation/working messages; the `add_messages` reducer is attached so node outputs are merged into this sequence.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Set by `execute_sql`: True when the query ran, False when running it raised.
    success_execution: bool
    # Rewrite counter: `write_sql` resets it to 0 and `rewrite_sql` increments it.
    rewrite_ct: int

### Nodes & Edges

In [8]:
import re
from typing import Annotated, Literal, Sequence
from typing_extensions import TypedDict
from pydantic import BaseModel, Field

from langchain import hub
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

from langgraph.prebuilt import tools_condition


# Model name passed to every ChatOpenAI(...) instance in the agent nodes.
openai_model_str = "gpt-4o-mini-2024-07-18"
# `rewrite_sql` gives up ("Can't find the answer") once its incremented counter reaches this value.
REWRITE_LIMIT  = 3

In [None]:
def extract_sql_str(raw_str):
  """
    Extract the SQL text from the first ```sql fenced block in `raw_str`.

    The match is tolerant of how LLMs format fences: the `sql` tag is
    case-insensitive, and any whitespace (including "\\r\\n" or none at all)
    may separate the fences from the query.

    Args:
      raw_str: text that may contain a ```sql ... ``` block; None/empty allowed.

    Returns:
      The query with surrounding whitespace removed, or None when no
      ```sql block is present (or the input is None/empty).
  """
  if not raw_str:
    return None

  # `\b` stops tags such as ```sqlite from matching; `\s*` absorbs "\n", "\r\n" or a missing newline.
  pattern = r"```sql\b\s*(.*?)\s*```"
  # re.DOTALL lets the query span several lines.
  match = re.search(pattern, raw_str, re.DOTALL | re.IGNORECASE)

  if match:
    return match.group(1).strip()
  return None


def write_sql(state):
  """
    Given user's question,
      1. LLM to check whether there's available table, if not generate the answer directly.
      2. Otherwise, write the SQL query to answer user's question.

    Args:
      state: agent state; the first entry of state["messages"] holds the user's question.

    Returns:
      A state update with one new message ('No table found' or the SQL text)
      and `rewrite_ct` reset to 0.
  """
  print("---WRITE SQL---")
  msgs = state["messages"]
  user_question = msgs[0].content

  model = ChatOpenAI(temperature=0, api_key=OPENAI_API_KEY,
                       model=openai_model_str, streaming=True)
  prompt = [
      SystemMessage(content="You are a SQL expert able to write SQL queries to answer user's questions."),
      HumanMessage(
          content=f""" \n
                  Given every table's schema in the database: {all_table_schemas} \n
                  and here's user's question: {user_question} \n
                  if there's no available table to answer the user question, output 'No table found' \n
                  otherwise, output the SQL query to answer user's question: """,
      )
  ]
  result = model.invoke(prompt).content
  if 'No table found' in result:
    output = 'No table found'
  else:
    # extract_sql_str returns None when the reply has no ```sql fence; None is not a
    # valid message, so fall back to the raw reply instead of putting None in the state.
    output = extract_sql_str(result) or result.strip()

  return {"messages": [output], "rewrite_ct": 0}


def validate_sql(state) -> Literal["write_sql", "execute_sql"]:
  """
    Validate the generated SQL query.

    Asks the LLM, via structured output, whether the latest generated SQL
    answers the user's question.

    Args:
      state: agent state; the first message holds the user's question and the
        last message holds the generated SQL.

    Returns:
      "execute_sql" when the model grades the SQL as valid, otherwise
      "write_sql" so the query is written again.
  """
  print("---VALIDATE SQL---")
  msgs = state["messages"]
  generated_sql = msgs[-1].content
  user_question = msgs[0].content

  class grade(BaseModel):
        """Binary score for SQL logic validation."""
        binary_score: str = Field(description="Whether the SQL logic is valid 'yes' or 'no'")

  model = ChatOpenAI(temperature=0, api_key=OPENAI_API_KEY,
                       model=openai_model_str, streaming=True)
  llm_with_tool = model.with_structured_output(grade)
  prompt = [
      SystemMessage(content="You are a SQL expert able to write SQL queries to answer user's questions."),
      HumanMessage(
          content=f""" \n
                  Given every table's schema in the database: {all_table_schemas} \n
                  and here's user's question: {user_question}. \n
                  Does generated SQL {generated_sql} answers user's question?
                  If so, output 'YES', otherwise output 'NO' """,
      )
  ]
  # A plain list of messages cannot be piped with `|`, and invoke() needs an input,
  # so the message list is passed straight to the structured-output model.
  validation_result = llm_with_tool.invoke(prompt).binary_score

  # The prompt asks for 'YES'/'NO' while the field description says 'yes'/'no'; compare case-insensitively.
  if validation_result.strip().lower() == "yes":
    return "execute_sql"
  return "write_sql"




def execute_sql(state):
  """
    Execute the generated SQL query.

    Args:
      state: agent state; the last message holds the SQL to run.

    Returns:
      On success, the query result as a new message with `success_execution` True.
      On failure, the failing SQL as a new message with `success_execution` False.
  """
  print("---EXECUTE SQL---")
  generated_sql = state["messages"][-1].content

  try:
    result = db.run(generated_sql)
  except Exception as exc:
    # `Exception` (not a bare except) so KeyboardInterrupt/SystemExit still propagate;
    # the error itself is printed so the failure is not silent.
    print('SQL Execution Error')
    print(exc)
    return {"messages": [generated_sql],
            "success_execution": False}

  return {"messages": [result],
          "success_execution": True}


def rewrite_sql(state):
  """
    Rewrite the generated SQL query.

    Args:
      state: agent state; the first message holds the user's question, the last
        message holds the SQL that failed, and `rewrite_ct` counts prior rewrites.

    Returns:
      A state update with the incremented `rewrite_ct` and one new message:
      the rewritten SQL, or "Can't find the answer" once the incremented
      counter reaches REWRITE_LIMIT.
  """
  rewrite_ct = state["rewrite_ct"] + 1
  if rewrite_ct >= REWRITE_LIMIT:
    return {"messages": ["Can't find the answer"], "rewrite_ct": rewrite_ct}

  print("---REWRITE SQL---")
  msgs = state["messages"]
  error_sql = msgs[-1].content
  user_question = msgs[0].content

  model = ChatOpenAI(temperature=0, api_key=OPENAI_API_KEY,
                       model=openai_model_str, streaming=True)

  prompt = [
      SystemMessage(content="You are a SQL expert able to write SQL queries to answer user's questions."),
      HumanMessage(
          content=f""" \n
                  Given every table's schema in the database: {all_table_schemas} \n
                  and here's user's question: {user_question}. \n
                  Previously written SQL {error_sql} got execution error \n
                  please find another way to write the SQL query to answer user's question: """,
      )
  ]
  reply = model.invoke(prompt).content
  # extract_sql_str returns None when the reply has no ```sql fence; None is not a
  # valid message, so fall back to the raw reply instead of putting None in the state.
  rewritten_sql = extract_sql_str(reply) or reply.strip()

  return {"messages": [rewritten_sql], "rewrite_ct": rewrite_ct}


def output_answer(state):
  """
    Emit the final result.

    Takes the content of the most recent message in state["messages"] and
    returns it as a single-message state update.
  """
  print("---OUTPUT ANSWER---")
  latest_msg = state["messages"][-1]
  return {"messages": [latest_msg.content]}