In [None]:
import os
import sqlite3
from dotenv import load_dotenv
from openai import OpenAI

# Load environment variables from a .env file, if present
# (presumably where the OpenAI credentials used by OpenAI() below are configured)
load_dotenv()

# Initialize OpenAI client
client = OpenAI()

# Fail fast if the database file is missing: sqlite3.connect() would otherwise
# silently create an empty chinook.db, and every later tool call would fail
# with "no such table".
if not os.path.exists('chinook.db'):
    raise FileNotFoundError(
        "chinook.db not found in the working directory; place the Chinook SQLite database here first."
    )

# Connect to SQLite database (left open at module level as `conn` / `cursor`)
conn = sqlite3.connect('chinook.db')
cursor = conn.cursor()

# List the tables in the database; the tools below inspect their schemas on demand
tables = [row[0] for row in cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")]
print(f"Available tables: {tables}")


Available tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


In [None]:
import json
import pandas as pd
from typing import List, Dict, Any, Optional

def query_sql_database(query: str) -> str:
    """
    Execute a read-only SQL query against the Chinook database.

    The SQL comes from the LLM, so the "no DML" rule in the system prompt is
    enforced here instead of trusted: the database is opened in SQLite
    read-only mode, and any statement that would modify it (INSERT, UPDATE,
    DELETE, DROP, ...) fails and is reported as an error.

    Args:
        query (str): A detailed and correct SQL query to execute

    Returns:
        str: JSON string. On success with rows: {"message", "data", "total_rows"},
        where "data" holds at most the first 100 rows as column->value dicts and
        "total_rows" is the full row count. On success with no rows: {"message"}.
        On failure: {"error", "message"}.
    """
    from contextlib import closing

    max_rows = 100  # cap on rows sent back to the model, to keep the context small
    try:
        # mode=ro rejects writes and does not create a missing file;
        # closing() releases the handle even when the query raises.
        with closing(sqlite3.connect('file:chinook.db?mode=ro', uri=True)) as conn:
            df = pd.read_sql_query(query, conn)

        if df.empty:
            return json.dumps({"message": "Query executed successfully but returned no results."})

        total_rows = len(df)
        if total_rows > max_rows:
            message = f"Query returned {total_rows} rows. Showing first {max_rows} rows."
        else:
            message = f"Query returned {total_rows} rows."

        result_data = {
            "message": message,
            "data": df.head(max_rows).to_dict('records'),
            "total_rows": total_rows,
        }
        # default=str: stringify any value json cannot encode natively instead of failing
        return json.dumps(result_data, indent=2, default=str)

    except Exception as e:
        error_msg = {
            "error": str(e),
            "message": "Query failed. Please check your SQL syntax and table/column names."
        }
        return json.dumps(error_msg, indent=2)

def info_sql_database(tables: str) -> str:
    """
    Get schema information, sample rows and the row count for specified tables.

    Args:
        tables (str): Comma-separated list of table names (e.g. "Artist, Album").
            Surrounding whitespace is ignored and empty entries are skipped.

    Returns:
        str: JSON string of the form {"tables": {name: {...}}}. An existing table
        maps to {"schema": [...], "sample_rows": [up to 3 rows], "row_count": N},
        where N is the table's total number of rows. An unknown table maps to
        {"error": "..."}. If the lookup itself fails, {"error", "message"} is
        returned instead.
    """
    from contextlib import closing

    sample_limit = 3
    try:
        table_list = [name.strip() for name in tables.split(',') if name.strip()]
        result = {"tables": {}}

        # Read-only: this tool only inspects; closing() frees the handle on any error.
        with closing(sqlite3.connect('file:chinook.db?mode=ro', uri=True)) as conn:
            cursor = conn.cursor()

            for table_name in table_list:
                # Check if table exists
                cursor.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                    (table_name,),
                )
                if not cursor.fetchone():
                    result["tables"][table_name] = {"error": f"Table '{table_name}' does not exist"}
                    continue

                # Identifiers cannot be bound as parameters; quote them (doubling any
                # embedded quote) so names with spaces or keywords still work.
                quoted = '"' + table_name.replace('"', '""') + '"'

                # Column schema: PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk)
                cursor.execute(f"PRAGMA table_info({quoted})")
                schema = [
                    {
                        "column_name": col[1],
                        "data_type": col[2],
                        "not_null": bool(col[3]),
                        "primary_key": bool(col[5]),
                    }
                    for col in cursor.fetchall()
                ]

                # A few sample rows keyed by column name
                cursor.execute(f"SELECT * FROM {quoted} LIMIT {sample_limit}")
                column_names = [desc[0] for desc in cursor.description]
                sample_data = [dict(zip(column_names, row)) for row in cursor.fetchall()]

                # row_count is the table's total row count, not the number of samples shown
                cursor.execute(f"SELECT COUNT(*) FROM {quoted}")
                row_count = cursor.fetchone()[0]

                result["tables"][table_name] = {
                    "schema": schema,
                    "sample_rows": sample_data,
                    "row_count": row_count,
                }

        # default=str: e.g. BLOB (bytes) values are not JSON-serializable
        return json.dumps(result, indent=2, default=str)

    except Exception as e:
        error_msg = {
            "error": str(e),
            "message": "Failed to retrieve table information."
        }
        return json.dumps(error_msg, indent=2)

def list_sql_database() -> str:
    """
    List all user tables in the Chinook database, sorted by name.

    SQLite's internal bookkeeping tables (names starting with "sqlite_", e.g.
    sqlite_sequence) are excluded because they are not useful to query.

    Returns:
        str: JSON string {"message", "tables"} on success, or {"error", "message"}
        on failure, including when chinook.db does not exist.
    """
    from contextlib import closing

    try:
        # Read-only: a missing file is reported as an error instead of being created
        # empty; closing() frees the handle even if the query raises.
        with closing(sqlite3.connect('file:chinook.db?mode=ro', uri=True)) as conn:
            rows = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' "
                "AND name NOT LIKE 'sqlite~_%' ESCAPE '~' ORDER BY name"
            ).fetchall()

        table_names = [row[0] for row in rows]

        result = {
            "message": f"Found {len(table_names)} tables in the database",
            "tables": table_names
        }
        return json.dumps(result, indent=2)

    except Exception as e:
        error_msg = {
            "error": str(e),
            "message": "Failed to list database tables."
        }
        return json.dumps(error_msg, indent=2)

def query_sql_checker(query: str) -> str:
    """
    Validate a SQL query for common mistakes before execution.

    The query is reviewed by the OpenAI model through the module-level ``client``.

    Args:
        query (str): SQL query to validate

    Returns:
        str: JSON string. Normally the model's object
        {"has_error", "corrected_query", "explanation"}. If the reply cannot be
        parsed as JSON, the same keys are filled in with has_error=False, the
        original query, and the raw reply text as the explanation. If the API
        call fails, {"error", "message"} is returned.
    """
    try:
        # Use OpenAI to check the query
        system_prompt = """You are an expert SQLite query checker and fixer.
            Your job is to analyze a given SQLite query, identify any syntax or logical errors, and return a corrected version if needed.
            You must always respond only in the following JSON format (no extra commentary or markdown):
            {"has_error": boolean, "corrected_query": "string", "explanation": "string" }

            If the query is valid, has_error should be false, and corrected_query should match the input.

            If there are issues, has_error should be true, and corrected_query should contain the fixed query.

            In explanation, briefly describe what was fixed or state "Query is valid." if there was nothing to change.

            When checking, look for:

            Common typos in SQL keywords (e.g. SELEC → SELECT, FORM → FROM)

            Incorrect table or column references (e.g. missing FROM clause)

            Missing or misplaced WHERE, JOIN, or GROUP BY clauses

            Improper string or identifier quoting

            Incomplete statements (e.g. missing semicolon, unclosed parentheses)

            SQLite-specific limitations (e.g. no RIGHT JOIN)

            Only return the JSON. Do not include explanations outside the JSON object or any formatting.
        """

        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Validate this SQLite query: {query}"}
            ],
            temperature=0.1,
            # JSON mode: ask the API for a bare JSON object (no markdown fences)
            response_format={"type": "json_object"},
        )

        content = response.choices[0].message.content
        try:
            # An empty/None reply is treated like unparseable output
            result = json.loads(content or "")
        except json.JSONDecodeError:
            # Fall back to the same keys the prompt asks the model for, so callers
            # always see one schema.
            result = {
                "has_error": False,
                "corrected_query": query,
                "explanation": content,
            }

        return json.dumps(result, indent=2)

    except Exception as e:
        error_msg = {
            "error": str(e),
            "message": "Failed to validate query."
        }
        return json.dumps(error_msg, indent=2)

In [11]:
# OpenAI Tool Definitions (Function Calling Format)

def _function_tool(name: str, description: str, properties: dict, required: list) -> dict:
    """Wrap one function's JSON-schema parameters in the OpenAI tool envelope."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


def _string_param(description: str) -> dict:
    """JSON-schema fragment for a single string argument."""
    return {"type": "string", "description": description}


tools = [
    _function_tool(
        "query_sql_database",
        "Execute a SQL query against the Chinook database. Returns results or error message.",
        {"query": _string_param("A detailed and correct SQL query to execute against the database")},
        ["query"],
    ),
    _function_tool(
        "info_sql_database",
        "Get schema information and sample rows for specified tables. Use list_sql_database first to see available tables.",
        {"tables": _string_param("Comma-separated list of table names to get information for (e.g., 'Artist, Album, Track')")},
        ["tables"],
    ),
    _function_tool(
        "list_sql_database",
        "List all available tables in the Chinook database.",
        {},
        [],
    ),
    _function_tool(
        "query_sql_checker",
        "Validate a SQL query for common mistakes before execution. Always use this before executing queries.",
        {"query": _string_param("SQL query to validate for syntax errors and common mistakes")},
        ["query"],
    ),
]

# Tool name -> Python callable, used to dispatch the model's tool calls
available_functions = dict(
    query_sql_database=query_sql_database,
    info_sql_database=info_sql_database,
    list_sql_database=list_sql_database,
    query_sql_checker=query_sql_checker,
)

print("OpenAI Tool Definitions created successfully!")
print("Tools available for function calling:")
for spec in (entry["function"] for entry in tools):
    print(f"- {spec['name']}: {spec['description']}")


OpenAI Tool Definitions created successfully!
Tools available for function calling:
- query_sql_database: Execute a SQL query against the Chinook database. Returns results or error message.
- info_sql_database: Get schema information and sample rows for specified tables. Use list_sql_database first to see available tables.
- list_sql_database: List all available tables in the Chinook database.
- query_sql_checker: Validate a SQL query for common mistakes before execution. Always use this before executing queries.


In [15]:
# Smoke-test the tools directly, outside the agent loop.
# list_sql_database() and info_sql_database("Artist, Album") can be exercised the same way.
test_query = "SELECT SELECT INTO * FROM Artist WHERE ArtistId = 1"  # deliberately malformed

for label, run in (
    ("query_sql_checker", lambda: query_sql_checker(test_query)),
    ("query_sql_database", lambda: query_sql_database("SELECT * FROM Artist LIMIT 5")),
):
    print(f"\n=== Testing {label} ===")
    result = run()
    print(result)



=== Testing query_sql_checker ===
{
  "has_error": true,
  "corrected_query": "SELECT * FROM Artist WHERE ArtistId = 1;",
  "explanation": "Removed the incorrect 'SELECT INTO' and added a semicolon at the end."
}

=== Testing query_sql_database ===
{
  "message": "Query returned 5 rows.",
  "data": [
    {
      "ArtistId": 1,
      "Name": "AC/DC"
    },
    {
      "ArtistId": 2,
      "Name": "Accept"
    },
    {
      "ArtistId": 3,
      "Name": "Aerosmith"
    },
    {
      "ArtistId": 4,
      "Name": "Alanis Morissette"
    },
    {
      "ArtistId": 5,
      "Name": "Alice In Chains"
    }
  ],
  "total_rows": 5
}


In [16]:
# Prompt parameters: the SQL dialect the agent writes, and the default row limit it applies
SQL_DIALECT = "SQLite"
DEFAULT_TOP_K = 5

system_message = f"""
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {SQL_DIALECT} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {DEFAULT_TOP_K} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Then you should query the schema of the most relevant tables.
"""

In [17]:
import json

def execute_function_call(function_name: str, arguments: dict) -> str:
    """
    Dispatch a model-requested tool call to the matching Python function.

    Args:
        function_name: Tool name, looked up in ``available_functions``.
        arguments: Keyword arguments decoded from the model's call. They are
            ignored for ``list_sql_database``, which takes none.

    Returns:
        The tool function's own return value, or a JSON ``{"error": ...}`` string
        when the name is unknown or the call raises.
    """
    if function_name not in available_functions:
        return json.dumps({"error": f"Function {function_name} not found"})

    handler = available_functions[function_name]
    # list_sql_database has no parameters, so any arguments the model sent are dropped
    call_kwargs = {} if function_name == "list_sql_database" else arguments
    try:
        return handler(**call_kwargs)
    except Exception as err:
        return json.dumps({"error": str(err)})

class SQLAgentWithTools:
    """SQL Agent that uses OpenAI function calling with SQL tools.

    Each ``ask`` call starts a fresh conversation (``system_message`` plus the
    question). It then loops: send the conversation together with ``tools``, run
    any requested tool calls via ``execute_function_call``, append their results,
    and repeat until the model replies without requesting tools.
    """

    def __init__(self, openai_client, model: str = "gpt-4o"):
        """
        Args:
            openai_client: OpenAI client used for chat completions.
            model: Chat model name sent with every request (default "gpt-4o").
        """
        self.client = openai_client
        self.model = model
        self.tools = tools
        self.available_functions = available_functions

    def ask(self, question: str, max_iterations: int = 5) -> str:
        """Ask a question and let the agent use tools to answer it.

        Args:
            question: Natural-language question about the database.
            max_iterations: Maximum number of model calls. Each round of tool
                calls consumes one.

        Returns:
            The model's final reply text, or a fallback message when every one
            of the ``max_iterations`` calls requested more tools.
        """
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": question},
        ]

        for _ in range(max_iterations):
            # Make API call with tools
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                tools=self.tools,
                tool_choice="auto",
            )

            response_message = response.choices[0].message
            messages.append(response_message)

            # No tool calls means the model has produced its final answer
            if not response_message.tool_calls:
                return response_message.content

            # Answer every tool call (presumably the API expects one "tool" message
            # per tool_call_id before the next request)
            for tool_call in response_message.tool_calls:
                function_name = tool_call.function.name
                try:
                    function_args = json.loads(tool_call.function.arguments)
                except json.JSONDecodeError as exc:
                    # Model-generated argument JSON may be malformed: report it back
                    # so the model can retry, instead of crashing ask().
                    print(f"⚠️ {function_name} called with invalid JSON arguments: {tool_call.function.arguments!r}")
                    function_response = json.dumps({"error": f"Invalid JSON arguments: {exc}"})
                else:
                    print(f"🔧 Calling {function_name} with args: {function_args}")
                    function_response = execute_function_call(function_name, function_args)

                messages.append({
                    "tool_call_id": tool_call.id,
                    "role": "tool",
                    "name": function_name,
                    "content": function_response,
                })

        return "Maximum iterations reached. Please try a simpler question."

# Initialize the SQL Agent with tools
sql_agent = SQLAgentWithTools(client)


SQL Agent with OpenAI Tools initialized successfully!
Ready to answer questions about the Chinook database!


In [19]:
# Example usage of the SQL Agent with Tools: (title, question) pairs run in order
example_questions = [
    ("Basic Question", "What are the top 5 best-selling artists by total sales?"),
    ("Complex Query", "Show me the average track length by genre in seconds, but only for genres with more than 10 tracks"),
]

for number, (title, question) in enumerate(example_questions, start=1):
    if number > 1:
        print("\n" + "=" * 60)  # visual separator between examples
    print(f"=== Example {number}: {title} ===")
    response = sql_agent.ask(question)
    print(response)


=== Example 1: Basic Question ===
🔧 Calling list_sql_database with args: {}
🔧 Calling info_sql_database with args: {'tables': 'Artist, Invoice, InvoiceLine, Track'}
🔧 Calling info_sql_database with args: {'tables': 'Album'}
🔧 Calling query_sql_checker with args: {'query': 'SELECT Ar.Name, SUM(Il.UnitPrice * Il.Quantity) AS TotalSales \nFROM Artist Ar\nJOIN Album Al ON Ar.ArtistId = Al.ArtistId\nJOIN Track Tr ON Al.AlbumId = Tr.AlbumId\nJOIN InvoiceLine Il ON Tr.TrackId = Il.TrackId\nGROUP BY Ar.ArtistId\nORDER BY TotalSales DESC\nLIMIT 5;'}
🔧 Calling query_sql_database with args: {'query': 'SELECT Ar.Name, SUM(Il.UnitPrice * Il.Quantity) AS TotalSales \nFROM Artist Ar\nJOIN Album Al ON Ar.ArtistId = Al.ArtistId\nJOIN Track Tr ON Al.AlbumId = Tr.AlbumId\nJOIN InvoiceLine Il ON Tr.TrackId = Il.TrackId\nGROUP BY Ar.ArtistId\nORDER BY TotalSales DESC\nLIMIT 5;'}
The top 5 best-selling artists by total sales are:

1. **Iron Maiden** with total sales of $138.60
2. **U2** with total sales of 