<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/t2sql_agent_llm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sqlite3
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Demo data: an in-memory SQLite database with a single `employees` table.
conn = sqlite3.connect(':memory:')
cursor = conn.cursor()
cursor.execute('''CREATE TABLE employees
             (id INTEGER PRIMARY KEY, name TEXT, department TEXT, salary REAL)''')
# Seed the three sample rows one statement at a time, then commit them.
for _insert_sql in (
    "INSERT INTO employees VALUES (1, 'Alice', 'Sales', 60000)",
    "INSERT INTO employees VALUES (2, 'Bob', 'Marketing', 70000)",
    "INSERT INTO employees VALUES (3, 'Charlie', 'Sales', 65000)",
):
    cursor.execute(_insert_sql)
conn.commit()

# Choose the device that the model and its input tensors are moved to.
# A torch.device is used instead of the pipeline-style 0 / -1 integers:
# `.to(-1)` is not a valid torch device, so the previous CPU fallback raised
# as soon as the model (or a tokenized batch) was moved.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Initialize the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl").to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
import sqlite3

# Rebuild the sample in-memory database; `conn` and `cursor` are rebound to
# the new connection.
conn = sqlite3.connect(':memory:')
cursor = conn.cursor()
# Schema first, then the three sample rows, executed in this order.
_setup_statements = [
    '''CREATE TABLE employees
             (id INTEGER PRIMARY KEY, name TEXT, department TEXT, salary REAL)''',
    "INSERT INTO employees VALUES (1, 'Alice', 'Sales', 60000)",
    "INSERT INTO employees VALUES (2, 'Bob', 'Marketing', 70000)",
    "INSERT INTO employees VALUES (3, 'Charlie', 'Sales', 65000)",
]
for _statement in _setup_statements:
    cursor.execute(_statement)
conn.commit()

def t2sql_agent(query):
    """
    Answer a natural language question using a fixed set of keyword rules.

    The question is lower-cased and checked against an ordered list of rules;
    the first rule whose keywords all occur in the question selects a canned
    SQL statement, which is executed on the module-level ``cursor``.

    Args:
      query: The natural language query.

    Returns:
      The fetched rows (a list of tuples) when a rule matches and the
      statement runs, the string "I don't understand that query." when no
      rule matches, or an "Error executing SQL query: ..." string when
      execution raises.
    """
    # Order matters: the first matching rule wins, so the Sales-specific rule
    # is consulted before the broader "all employees" rule.
    rules = (
        (("highest salary",),
         "SELECT MAX(salary) FROM employees"),
        (("average salary", "marketing"),
         "SELECT AVG(salary) FROM employees WHERE department = 'Marketing'"),
        (("employees", "sales"),
         "SELECT * FROM employees WHERE department = 'Sales'"),
        (("all employees",),
         "SELECT * FROM employees"),
    )

    text = query.lower()
    sql_query = next(
        (sql for keywords, sql in rules
         if all(keyword in text for keyword in keywords)),
        None,
    )
    if sql_query is None:
        return "I don't understand that query."

    try:
        cursor.execute(sql_query)
        return cursor.fetchall()
    except Exception as e:
        return f"Error executing SQL query: {e}"

# Demo of the rule-based agent: each sample question is printed with a
# "Query:" label, passed to t2sql_agent, and the returned value is printed
# with a "Result:" label.
user_queries = [
    "What is the highest salary in the company?",
    "Show all employees working in the Sales department",
    "What is the average salary of employees in the Marketing department?",
    "Show all employees"  # no department mentioned: asks for the whole table
]

for user_query in user_queries:
    print(f"Query: {user_query}")
    # The return value is either the fetched rows or a message string.
    results = t2sql_agent(user_query)
    print(f"Result: {results}\n")

Query: What is the highest salary in the company?
Result: [(70000.0,)]

Query: Show all employees working in the Sales department
Result: [(1, 'Alice', 'Sales', 60000.0), (3, 'Charlie', 'Sales', 65000.0)]

Query: What is the average salary of employees in the Marketing department?
Result: [(70000.0,)]

Query: Show all employees
Result: [(1, 'Alice', 'Sales', 60000.0), (2, 'Bob', 'Marketing', 70000.0), (3, 'Charlie', 'Sales', 65000.0)]



In [10]:
def t2sql_agent_with_llm(query):
    """
    Translates a natural language query into SQL using an LLM and runs it.

    The query is embedded in a few-shot prompt, the module-level ``tokenizer``
    and ``model`` generate a SQL statement from that prompt, and the statement
    is executed on the module-level ``cursor``.

    Args:
      query: The natural language query.

    Returns:
      The rows fetched for the generated statement (a list of tuples), or an
      "Error executing SQL query: ..." string when the model produced no
      statement or the statement failed to execute.
    """

    # Use the LLM to generate the SQL query
    prompt = f"""
    You are a SQL expert. Your task is to translate natural language queries into correct SQL queries that can be executed on a database.
    The database has one table named 'employees' with the following columns: id, name, department, and salary.

    ### Examples:
    Natural Language Query: What is the name of the employee with the highest salary?
    SQL Query: SELECT name FROM employees ORDER BY salary DESC LIMIT 1

    Natural Language Query: What are the names of all employees in the Sales department?
    SQL Query: SELECT name FROM employees WHERE department = 'Sales'

    Natural Language Query: What is the average salary in the Marketing department?
    SQL Query: SELECT AVG(salary) FROM employees WHERE department = 'Marketing'

    ### Now, write ONLY the SQL query that answers the following natural language query:
    {query}
    SQL Query:"""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # With an encoder-decoder model, max_length caps the generated sequence;
    # it does not have to leave room for the prompt.
    outputs = model.generate(**inputs, max_length=300)

    # Assumes `model` is the seq2seq model loaded earlier in the notebook: its
    # decoded output contains only the generated text (the prompt is not
    # echoed back), so trimming whitespace is all the clean-up required.
    generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # sqlite3 accepts an empty statement and yields no rows, which would be
    # indistinguishable from a query that legitimately matched nothing.
    if not generated_sql:
        return "Error executing SQL query: the model generated an empty statement"

    # NOTE(review): the statement is written by the model and executed without
    # validation, so it is not guaranteed to be a read-only SELECT. Acceptable
    # for this in-memory demo; restrict it before pointing at real data.
    try:
        # Execute the generated SQL query
        cursor.execute(generated_sql)
        return cursor.fetchall()
    except Exception as e:
        return f"Error executing SQL query: {generated_sql} - {e}"

In [17]:
def t2sql_agent_with_llm(query):
    """
    Translates a natural language query into SQL using an LLM with a highly guided prompt.

    The query is appended to a prompt holding three worked examples, the
    module-level ``tokenizer`` and ``model`` generate a SQL statement from it,
    and the statement is executed on the module-level ``cursor``.

    Args:
      query: The natural language query.

    Returns:
      The rows fetched for the generated statement (a list of tuples), or an
      "Error executing SQL query: ..." string when the model produced no
      statement or the statement failed to execute.
    """

    # Use the LLM to generate the SQL query
    prompt = f"""
    You are a helpful AI assistant that translates natural language queries to SQL.
    The database has one table named 'employees' with the following columns: id, name, department, and salary.
    You MUST generate a valid SQL query that accurately reflects the information requested in the natural language query.
    Pay close attention to the column names and ensure all necessary columns are included in the SELECT statement.

    Example 1:
    Natural Language Query: What is the highest salary in the company?
    SQL Query: SELECT MAX(salary) FROM employees

    Example 2:
    Natural Language Query: Show all employees working in the Sales department.
    SQL Query: SELECT * FROM employees WHERE department = 'Sales'

    Example 3:
    Natural Language Query: What is the average salary of employees in the Marketing department?
    SQL Query: SELECT AVG(salary) FROM employees WHERE department = 'Marketing'

    Natural Language Query: {query}
    SQL Query:"""

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # With an encoder-decoder model, max_length caps the generated sequence;
    # the prompt length does not count against it.
    outputs = model.generate(**inputs, max_length=150)

    # Assumes `model` is the seq2seq model loaded earlier in the notebook: its
    # decoded output contains only the generated text (the prompt is not
    # echoed back), so trimming whitespace is all the clean-up required.
    generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # sqlite3 accepts an empty statement and yields no rows, which would be
    # indistinguishable from a query that legitimately matched nothing.
    if not generated_sql:
        return "Error executing SQL query: the model generated an empty statement"

    # NOTE(review): the statement is written by the model and executed without
    # validation, so it is not guaranteed to be a read-only SELECT. Acceptable
    # for this in-memory demo; restrict it before pointing at real data.
    try:
        # Execute the generated SQL query
        cursor.execute(generated_sql)
        return cursor.fetchall()
    except Exception as e:
        return f"Error executing SQL query: {generated_sql} - {e}"

In [18]:
# Demo of the LLM-backed agent: each sample question is printed with a
# "Query:" label, passed to t2sql_agent_with_llm, and the returned value is
# printed with a "Result:" label.
user_queries = [
    "What is the highest salary in the company?",
    "Show all employees working in the Sales department",
    "What is the average salary of employees in the Marketing department?",
    "Show all employees"
]

for user_query in user_queries:
    print(f"Query: {user_query}")
    # The return value is either the fetched rows or an error string.
    results = t2sql_agent_with_llm(user_query)  # Using the LLM function
    print(f"Result: {results}\n")

Query: What is the highest salary in the company?
Result: [(70000.0,)]

Query: Show all employees working in the Sales department
Result: [(1, 'Alice', 'Sales', 60000.0), (3, 'Charlie', 'Sales', 65000.0)]

Query: What is the average salary of employees in the Marketing department?
Result: [(70000.0,)]

Query: Show all employees
Result: [(1, 'Alice', 'Sales', 60000.0), (2, 'Bob', 'Marketing', 70000.0), (3, 'Charlie', 'Sales', 65000.0)]

