In [65]:
! pip install langchain python-dotenv langgraph langchain-openai langchain-community langchain-core




In [66]:
import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.messages import ToolMessage

load_dotenv()

openai_api_key = os.getenv("OPENAI_API_KEY")
# database_url = "sqlite:///grades.db"
database_url = os.getenv("DATABASE_URL")

db = SQLDatabase.from_uri(database_url)

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=openai_api_key)
# result = llm.invoke("What is the average height of trees in india?" )
# print(result.content)

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

agent = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type="openai-tools",
    max_iterations=15,
    max_execution_time=30.0
)

agent.invoke({"input": "What grades alice got?"})



# sqlalchamy create engine
# CREATE TABLE students (
#     id INTEGER PRIMARY KEY AUTOINCREMENT,
#     name TEXT NOT NULL,
#     subject TEXT NOT NULL,
#     grade INTEGER NOT NULL
# );











[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mstudents[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'students'}`


[0m[33;1m[1;3m
CREATE TABLE students (
	id INTEGER, 
	name TEXT NOT NULL, 
	subject TEXT NOT NULL, 
	grade INTEGER NOT NULL, 
	PRIMARY KEY (id)
)

/*
3 rows from students table:
id	name	subject	grade
1	Alice Johnson	Math	95
2	Bob Smith	Math	87
3	Carol Davis	Math	92
*/[0m[32;1m[1;3m
Invoking: `sql_db_query_checker` with `{'query': "SELECT subject, grade FROM students WHERE name = 'Alice Johnson'"}`


[0m[36;1m[1;3m```sql
SELECT subject, grade FROM students WHERE name = 'Alice Johnson'
```[0m[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': "SELECT subject, grade FROM students WHERE name = 'Alice Johnson'"}`


[0m[36;1m[1;3m[('Math', 95), ('Science', 88), ('English', 92), ('History', 90)][0m[32;1m[1;3mAlice received the following grades:

- Math: 95
-

{'input': 'What grades alice got?',
 'output': 'Alice received the following grades:\n\n- Math: 95\n- Science: 88\n- English: 92\n- History: 90'}

Building SQL Agents with LangGraph


In [67]:
# Fixed LangGraph SQL Agent Implementation
from typing import Annotated, List
from typing_extensions import TypedDict
from langgraph.graph.message import AnyMessage, add_messages
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.orm import sessionmaker
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END
from langchain_core.messages import AIMessage, HumanMessage
import json

# Database connection configuration
DATABASE_URL = "sqlite:///grades.db"
engine = create_engine(DATABASE_URL, echo=False)
Session = sessionmaker(bind=engine)

# Initialize LLM
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=openai_api_key)

class SqlAgentState(TypedDict):
    messages: Annotated[List[AnyMessage], add_messages]
    user_question: str
    sql_query: str
    query_result: str
    final_answer: str
    error: str

def list_tables_node(state: SqlAgentState) -> SqlAgentState:
    """List all tables in the database"""
    try:
        inspector = inspect(engine)
        tables = inspector.get_table_names()
        return {
            "messages": [AIMessage(content=f"Available tables: {', '.join(tables)}")]
        }
    except Exception as e:
        return {
            "error": str(e),
            "messages": [AIMessage(content=f"Error listing tables: {str(e)}")]
        }

def get_schema_node(state: SqlAgentState) -> SqlAgentState:
    """Get the schema of the database"""
    try:
        inspector = inspect(engine)
        schema_info = []
        for table_name in inspector.get_table_names():
            columns = inspector.get_columns(table_name)
            column_info = [f"{col['name']} ({col['type']})" for col in columns]
            schema_info.append(f"Table {table_name}: {', '.join(column_info)}")
        schema_str = "\n".join(schema_info)
        return {
            "messages": [AIMessage(content=f"Database schema:\n{schema_str}")]
        }
    except Exception as e:
        return {
            "error": str(e),
            "messages": [AIMessage(content=f"Error retrieving schema: {str(e)}")]
        }

def query_generation_node(state: SqlAgentState) -> SqlAgentState:
    """Generate a SQL query from the user question"""
    try:
        # Extract user question from the initial message
        user_question = state.get("user_question", "")
        if not user_question:
            user_question = next((msg.content for msg in state["messages"] if isinstance(msg, HumanMessage)), "")
        
        # Get schema information from previous messages
        schema_message = next((msg.content for msg in state["messages"] if "Database schema:" in msg.content), "")
        
        # Construct prompt for query generation
        query_prompt = f"""
        You are a SQL expert. Based on the user question and database schema, generate a valid SQLite query.
        
        User question: {user_question}
        
        Database schema:
        {schema_message}
        
        Generate a SQL query to answer the question. Ensure the query is:
        - Syntactically correct for SQLite
        - Matches column names and table names exactly
        - Handles case sensitivity (e.g., 'Alice' vs 'alice')
        - Return ONLY the SQL query without any markdown formatting or extra text
        
        Example: SELECT * FROM students WHERE name = 'Alice Johnson'
        """
        
        # Use LLM to generate the SQL query
        response = llm.invoke(query_prompt)
        sql_query = response.content.strip()
        
        # Clean up the query - remove markdown formatting if present
        if sql_query.startswith("```sql"):
            sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
        elif sql_query.startswith("```"):
            sql_query = sql_query.replace("```", "").strip()
        
        return {
            "sql_query": sql_query,
            "user_question": user_question,
            "messages": [AIMessage(content=f"Generated SQL query: {sql_query}")]
        }
    except Exception as e:
        return {
            "error": str(e),
            "messages": [AIMessage(content=f"Error generating query: {str(e)}")]
        }

def execute_query_node(state: SqlAgentState) -> SqlAgentState:
    """Execute a SQL query and return the results"""
    try:
        sql_query = state.get("sql_query", "")
        if not sql_query:
            return {
                "error": "No SQL query provided",
                "messages": [AIMessage(content="Error: No SQL query provided")]
            }
        
        with engine.connect() as connection:
            result = connection.execute(text(sql_query))
            columns = result.keys()
            rows = result.fetchall()
            
            # Format results
            result_data = [dict(zip(columns, row)) for row in rows]
            
            # Generate human-readable answer
            user_question = state.get("user_question", "").lower()
            if "alice" in user_question and "grade" in user_question:
                if result_data:
                    # Group by subject and grade
                    subjects_grades = [(row.get("subject", ""), row.get("grade", "")) for row in result_data]
                    grade_text = []
                    for subject, grade in subjects_grades:
                        grade_text.append(f"{subject}: {grade}")
                    final_answer = f"Alice Johnson received the following grades:\n" + "\n".join([f"- {text}" for text in grade_text])
                else:
                    final_answer = "No grades found for Alice."
            else:
                # Generic response
                if result_data:
                    final_answer = f"Query returned {len(result_data)} result(s):\n" + json.dumps(result_data, indent=2)
                else:
                    final_answer = "No results found."
            
            return {
                "query_result": json.dumps(result_data, indent=2),
                "final_answer": final_answer,
                "messages": [AIMessage(content=final_answer)]
            }
    except Exception as e:
        return {
            "error": str(e),
            "messages": [AIMessage(content=f"Error executing query: {str(e)}")]
        }

# Create workflow
workflow = StateGraph(SqlAgentState)

# Add nodes
workflow.add_node("list_tables", list_tables_node)
workflow.add_node("get_schema", get_schema_node)
workflow.add_node("query_gen", query_generation_node)
workflow.add_node("execute_query", execute_query_node)

# Define processing flow
workflow.add_edge(START, "list_tables")
workflow.add_edge("list_tables", "get_schema")
workflow.add_edge("get_schema", "query_gen")
workflow.add_edge("query_gen", "execute_query")
workflow.add_edge("execute_query", END)

# Compile and run the workflow
app = workflow.compile()

# Example invocation
result = app.invoke({
    "messages": [HumanMessage(content="What grades did Alice get?")], 
    "user_question": "What grades did Alice get?"
})

print("=== Final Result ===")
print(f"Final Answer: {result.get('final_answer', 'No final answer')}")
print(f"SQL Query Used: {result.get('sql_query', 'No query')}")
print(f"Any Errors: {result.get('error', 'None')}")


=== Final Result ===
Final Answer: No grades found for Alice.
SQL Query Used: SELECT grade FROM students WHERE name = 'Alice'
Any Errors: None


In [68]:
# Test the SQL agent with different queries
test_queries = [
    "What grades did Alice Johnson get?",
    "Who got the highest grade in Math?",
    "What is the average grade for each subject?",
    "How many students are there in total?"
]

print("Testing SQL Agent with different queries:\n")
print("=" * 60)

for i, query in enumerate(test_queries, 1):
    print(f"\n{i}. Query: {query}")
    print("-" * 40)
    
    try:
        result = app.invoke({
            "messages": [HumanMessage(content=query)], 
            "user_question": query
        })
        
        print(f"Answer: {result.get('final_answer', 'No answer')}")
        print(f"SQL: {result.get('sql_query', 'No query')}")
        
        if result.get('error'):
            print(f"Error: {result['error']}")
            
    except Exception as e:
        print(f"Error running query: {str(e)}")
    
    print("=" * 60)


Testing SQL Agent with different queries:


1. Query: What grades did Alice Johnson get?
----------------------------------------
Answer: Alice Johnson received the following grades:
- : 95
- : 88
- : 92
- : 90
SQL: SELECT grade FROM students WHERE name = 'Alice Johnson'

2. Query: Who got the highest grade in Math?
----------------------------------------
Answer: Query returned 1 result(s):
[
  {
    "name": "Grace Lee",
    "grade": 96
  }
]
SQL: SELECT name, grade FROM students WHERE subject = 'Math' ORDER BY grade DESC LIMIT 1

3. Query: What is the average grade for each subject?
----------------------------------------
Answer: Query returned 4 result(s):
[
  {
    "subject": "English",
    "average_grade": 88.46153846153847
  },
  {
    "subject": "History",
    "average_grade": 88.23076923076923
  },
  {
    "subject": "Math",
    "average_grade": 89.1923076923077
  },
  {
    "subject": "Science",
    "average_grade": 88.92307692307692
  }
]
SQL: SELECT subject, AVG(grade) AS a