In [1]:
from dotenv import load_dotenv

load_dotenv()

True

In [2]:
import sqlite3
import time

from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langsmith import Client
from langsmith.evaluation import (
    EvaluationResult,
    EvaluationResults,
    RunEvaluator,
    evaluate,
)

from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.prompts import PromptTemplate

In [3]:
# Create a SQLite database and add some sample data
conn = sqlite3.connect("sample.db")
cursor = conn.cursor()

# Create a sample table
db_schema = """
CREATE TABLE IF NOT EXISTS users (
    id INTEGER PRIMARY KEY,
    name TEXT NOT NULL,
    age INTEGER,
    email TEXT
)
"""

cursor.execute(db_schema)

# Insert some sample data
sample_data = [
    (1, "Alice", 28, "alice@email.com"),
    (2, "Bob", 35, "bob@email.com"),
    (3, "Charlie", 42, "charlie@email.com"),
    (4, "David", 31, "david@email.com"),
]
cursor.executemany(
    "INSERT OR REPLACE INTO users (id, name, age, email) VALUES (?, ?, ?, ?)",
    sample_data,
)
conn.commit()
conn.close()

In [4]:
# Create a dataset in LangSmith (if not already created)
client = Client()
dataset_name = "Text-to-SQL Evaluation Dataset"
client.delete_dataset(dataset_name=dataset_name)
dataset = client.create_dataset(dataset_name)
instruction = "Generate the SQL query based on the given natural language query."

# Add examples to the dataset
examples = [
    {
        "input": f"{instruction}\nquestion: How many users are in the database?\db schema: {db_schema}",
        "expected_output": "SELECT COUNT(*) FROM users;",
    },
    {
        "input": f"{instruction}\nWhat is the average age of users?\ndb schema: {db_schema}",
        "expected_output": "SELECT AVG(age) FROM users;",
    },
    {
        "input": f"{instruction}\nList all users older than 30\ndb schema: {db_schema}",
        "expected_output": "SELECT * FROM users WHERE age > 30;",
    },
    {
        "input": f"{instruction}\nWhat is the email of the oldest user?\ndb schema: {db_schema}",
        "expected_output": "SELECT email FROM users WHERE age = (SELECT MAX(age) FROM users);",
    },
]

for example in examples:
    client.create_example(
        inputs={"query": example["input"]},
        outputs={"expected_sql": example["expected_output"]},
        dataset_id=dataset.id,
    )

In [5]:
def get_query_plan(query):
    """
    クエリの実行計画を取得する関数です。

    Args:
        query (str): 実行計画を取得したいクエリ

    Returns:
        list: クエリの実行計画の結果を含むリスト

    Examples:
        >>> query = "SELECT * FROM users"
        >>> get_query_plan(query)
        [('0', '0', '0', 'SCAN TABLE users')]

    Raises:
        None
    """
    conn = sqlite3.connect("sample.db")
    cursor = conn.cursor()
    cursor.execute(f"EXPLAIN QUERY PLAN {query}")
    plan = cursor.fetchall()
    conn.close()
    return plan


def measure_execution_time(query: str) -> float:
    """
    クエリの実行時間を計測する関数です。

    Args:
        query (str): 実行するクエリ文

    Returns:
        float: クエリの実行時間（秒単位）

    Examples:
        >>> query = "SELECT * FROM table"
        >>> execution_time = measure_execution_time(query)
        >>> print(execution_time)
        0.123456789
    """
    conn = sqlite3.connect("sample.db")
    cursor = conn.cursor()
    start_time = time.time()
    cursor.execute(query)
    end_time = time.time()
    conn.close()
    return end_time - start_time


def calculate_efficiency_score(query_plan: list, execution_time: float) -> float:
    """
    クエリの効率スコアを計算します。

    Args:
        query_plan (list): クエリの実行計画を表すリスト
        execution_time (float): クエリの実行時間（秒単位）

    Returns:
        float: 計算された効率スコア（0.1から1.0の範囲）

    Raises:
        None

    Examples:
        >>> query_plan = [("SCAN TABLE", "users"), ("USE TEMP B-TREE", "index")]
        >>> execution_time = 0.05
        >>> calculate_efficiency_score(query_plan, execution_time)
        0.72
    """
    # This is a simple heuristic and can be improved
    score = 1.0
    for step in query_plan:
        if "SCAN TABLE" in step[3]:
            score *= 0.8  # フルテーブルスキャンをペナルティとして適用
        if "USE TEMP B-TREE" in step[3]:
            score *= 0.9  # 一時的なBツリーの使用をペナルティとして適用

    # 実行時間が長い場合にペナルティを適用
    if execution_time > 0.1:  # 必要に応じてこの閾値を調整してください
        score *= 0.9

    return max(0.1, min(score, 1.0))  # スコアが0.1から1.0の範囲になるように調整

In [6]:
# LLMでSQLクエリを評価するカスタム評価器を作成
class LLMSQLEvaluator(RunEvaluator):
    def __init__(self):
        self.llm = ChatOpenAI(temperature=0, model="gpt-4o")
        self.prompt = PromptTemplate(
            input_variables=["query", "generated_sql", "expected_sql", "db_schema"],
            template="""
            You are an expert SQL reviewer. Your task is to evaluate the generated SQL query against the expected SQL query for a given natural language question. Consider the database schema provided.

            Database Schema:
            {db_schema}

            Natural Language Query: {query}
            Generated SQL: {generated_sql}
            Expected SQL: {expected_sql}

            Please evaluate the generated SQL based on the following criteria:
            1. Correctness: Does the generated SQL correctly answer the natural language query?
            2. Efficiency: Is the generated SQL optimized and efficient?
            3. Readability: Is the generated SQL easy to read and understand?

            Provide a score from 0 to 1 (0 being the worst, 1 being the best) for each criterion, and an overall score.
            Also, provide a brief explanation for your evaluation.

            Format your response as follows:
            Correctness Score: [score]
            Efficiency Score: [score]
            Readability Score: [score]
            Overall Score: [score]
            Explanation: [your explanation]
            """,
        )
        self.chain = self.prompt | self.llm

    def evaluate_run(self, run, example) -> EvaluationResult:
        """
        モデルの実行結果を評価します。

        Args:
            run (Run): 実行結果を含むオブジェクト
            example (Example): 入力と出力の情報を含むオブジェクト

        Returns:
            EvaluationResult: 評価結果を表すオブジェクト
        """
        query = example.inputs.get("query", "")
        generated_sql = run.outputs.get("output", "")
        expected_sql = example.outputs.get("expected_sql", "")
        db = SQLDatabase.from_uri("sqlite:///sample.db")
        db_schema = db.get_table_info()

        evaluation = self.chain.invoke(
            {
                "query": query,
                "generated_sql": generated_sql,
                "expected_sql": expected_sql,
                "db_schema": db_schema,
            }
        )

        # Parse the evaluation results
        lines = evaluation.content.strip().split("\n")
        scores = {}
        for line in lines:
            if "Score:" in line:  # スコアを取得
                key, value = line.split(":")
                scores[key.strip()] = float(value.strip())

        explanation = lines[-1].replace("Explanation:", "").strip()

        results = [
            EvaluationResult(
                key="overall",
                score=scores.get("Overall Score", 0),
                comment=f"Overall: {explanation}",
            ),
            EvaluationResult(
                key="correctness",
                score=scores.get("Correctness Score", 0),
            ),
            EvaluationResult(
                key="efficiency",
                score=scores.get("Efficiency Score", 0),
            ),
            EvaluationResult(
                key="readability",
                score=scores.get("Readability Score", 0),
            ),
        ]

        # Return EvaluationResults object
        return EvaluationResults(results=results)

In [7]:
llm = ChatOpenAI(temperature=0, model="gpt-4o")


# Define the target function
def target_function(x):
    return llm.invoke(x["query"])


# Run evaluation
results = evaluate(
    target_function,
    dataset.name,
    evaluators=[LLMSQLEvaluator()],
    experiment_prefix="SQL Generation Evaluation",
)

View the evaluation results for experiment: 'SQL Generation Evaluation-03344ca4' at:
https://smith.langchain.com/o/bd14a154-65e7-52b4-bdce-b9a16d5e3513/datasets/f70d7596-e0d8-4ed3-a2f2-320ebf7af329/compare?selectedSessions=ef090748-9000-4271-9670-62a2cb843911




0it [00:00, ?it/s]

In [8]:
for res in results:
    print(f"Evaluation complete. Run ID: {res['run'].id}")
    print(f"View results at: {client.get_run_url(run=res['run'])}")

Evaluation complete. Run ID: 3ebd5604-cc23-4c64-a327-1b8767c3cb46
View results at: https://smith.langchain.com/o/bd14a154-65e7-52b4-bdce-b9a16d5e3513/projects/p/dc1fff1d-e827-4ee4-b948-e71e0914f7dd/r/3ebd5604-cc23-4c64-a327-1b8767c3cb46?poll=true
Evaluation complete. Run ID: 74435322-1710-4f43-b45a-7498de4cc0dd
View results at: https://smith.langchain.com/o/bd14a154-65e7-52b4-bdce-b9a16d5e3513/projects/p/dc1fff1d-e827-4ee4-b948-e71e0914f7dd/r/74435322-1710-4f43-b45a-7498de4cc0dd?poll=true
Evaluation complete. Run ID: e49cb64f-84c0-41f7-9267-35c4c47a824f
View results at: https://smith.langchain.com/o/bd14a154-65e7-52b4-bdce-b9a16d5e3513/projects/p/dc1fff1d-e827-4ee4-b948-e71e0914f7dd/r/e49cb64f-84c0-41f7-9267-35c4c47a824f?poll=true
Evaluation complete. Run ID: f9786659-91ed-49a2-8bdd-b8b9a9d82cd9
View results at: https://smith.langchain.com/o/bd14a154-65e7-52b4-bdce-b9a16d5e3513/projects/p/dc1fff1d-e827-4ee4-b948-e71e0914f7dd/r/f9786659-91ed-49a2-8bdd-b8b9a9d82cd9?poll=true


In [9]:
# SQLiteコネクト用のtoolを作成
db = SQLDatabase.from_uri("sqlite:///sample.db")
llm = ChatOpenAI(temperature=0, model="gpt-4")
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

agent = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)

results = evaluate(
    lambda x: agent.invoke(x["query"]),
    dataset.name,
    evaluators=[LLMSQLEvaluator()],
    experiment_prefix="SQL Generation Evaluation",
)

View the evaluation results for experiment: 'SQL Generation Evaluation-3d600262' at:
https://smith.langchain.com/o/bd14a154-65e7-52b4-bdce-b9a16d5e3513/datasets/f70d7596-e0d8-4ed3-a2f2-320ebf7af329/compare?selectedSessions=17b296f3-f68f-4993-b482-d4dd967ea682




0it [00:00, ?it/s]



[1m> Entering new SQL Agent Executor chain...[0m


[1m> Entering new SQL Agent Executor chain...[0m


[1m> Entering new SQL Agent Executor chain...[0m


[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: ""[0m[38;5;200m[1;3musers[0m[32;1m[1;3mAction: sql_db_list_tables
Action Input: ""[0m[38;5;200m[1;3musers[0m[32;1m[1;3mAction: sql_db_list_tables
Action Input: ""[0m[38;5;200m[1;3musers[0m[32;1m[1;3mAction: sql_db_list_tables
Action Input: ""[0m[38;5;200m[1;3musers[0m[32;1m[1;3mThe 'users' table is available in the database. Now, I need to construct a SQL query to count the number of users.
Action: sql_db_query_checker
Action Input: "SELECT COUNT(*) FROM users"[0m[32;1m[1;3mThe 'users' table seems to be the most relevant for this query. I should check its schema to confirm it has the necessary fields.
Action: sql_db_schema
Action Input: users[0m[33;1m[1;3m
CREATE TABLE users (
	id INTEGER, 
	nam