In [71]:
from dotenv import load_dotenv
from typing import List, Optional, Dict, Any
import os
import dspy
import mlflow
import pydantic
import json
from db_schema_example import example_json_schema
from dspy.datasets import DataLoader
from datasets import load_dataset
from dspy import Example, Prediction


In [181]:
db_schema = example_json_schema()

# Load environment variables
load_dotenv()
anthropic_api_key = os.getenv('anthropic_api_key')
gcloud_bq_project = os.getenv('gcloud_bq_project')

# MLflow configuration
mlflow.dspy.autolog()

# LM Configuration
lm = dspy.LM('claude-sonnet-4-5-20250929', 
             api_key=anthropic_api_key,
             temperature=0.0)
dspy.configure(lm=lm)
dspy.configure_cache(
    enable_disk_cache=False,
    enable_memory_cache=False,
)

# RULES = [
#     "Use CTEs for complex queries",
#     #"Use table aliases to shorten table names",
#     "Use column aliases for computed columns",
#     "Only use SQL features supported by the target dialect (given in schema)",
#     "Always use explicit JOINs instead of WHERE-based joins",
#     #"Always use fully qualified column names (table.column)",
#     "Ensure the SQL is syntactically correct",
#     "Ensure the SQL is semantically correct (e.g. GROUP BY columns)",
#     "Ensure the SQL addresses all subtasks in the decomposition",
#     "Ensure the SQL is safe against SQL injection (e.g. no direct string interpolation)",
#     "Ensure the SQL returns correct results for the question",
#     "Ensure the SQL is efficient and scalable (e.g. avoid SELECT *)",
#     "Ensure the SQL uses appropriate indexing (e.g. WHERE clauses on indexed columns)",
#     "Ensure the SQL uses appropriate aggregation functions (e.g. COUNT, SUM, AVG)",
#     "Ensure the SQL uses appropriate filtering (e.g. WHERE, HAVING)",
#     "Ensure the SQL uses appropriate sorting (e.g. ORDER BY)",
#     "Ensure the SQL uses appropriate grouping (e.g. GROUP BY)",
#     "Ensure the SQL uses appropriate joins (e.g. INNER JOIN, LEFT JOIN)",
#     "When you use a WHERE clause, use `WHERE 1 = 1` as the first condition to make appending conditions easier. Subsequent conditions should use AND/OR on a new row.",
# ]

#==============================#
#  ----  BigQuery Functions --- #
#==============================#
def clean_sql(sql: str) -> str:
    """Remove markdown formatting from SQL if present.

    Args:
        sql: The SQL query string, potentially with markdown formatting

    Returns:
        Cleaned SQL string without markdown
    """
    sql_clean = sql.strip()
    if sql_clean.startswith("```sql"):
        sql_clean = sql_clean.split("```sql")[1].split("```")[0].strip()
    elif sql_clean.startswith("```"):
        sql_clean = sql_clean.split("```")[1].split("```")[0].strip()
    return sql_clean

#==============================#
#  ----  1. Data Models   ---- #
#==============================#
class JoinSpec(pydantic.BaseModel):
    left: str
    right: str
    on: str
    type: str | None = pydantic.Field(default=None, description="join type if relevant")

class DbSchema(pydantic.BaseModel):
    tables: Dict[str, List[str]]
    primary_keys: Dict[str, str]
    columns: Dict[str, str] = pydantic.Field(
        description="The column names and their respective data types (e.g., INT, VARCHAR, DATE)"
    )
    joins: List[JoinSpec] = []

class SubProblem(pydantic.BaseModel):
    clause: str = pydantic.Field(
        description="SQL clause name like WHERE, GROUP BY, HAVING, ORDER BY"
    )
    goal: str = pydantic.Field(description="What this clause should accomplish")

class Decomposition(pydantic.BaseModel):
    subproblems: List[SubProblem] = pydantic.Field(
        description="List of subproblems with clause and goal"
    )

class QueryPlan(pydantic.BaseModel):
    steps: List[str] = pydantic.Field(description="Step-by-step plan to generate the query")
    aggregations: List[str] = pydantic.Field(description="List of aggregations to be used")
    filters: List[str] = pydantic.Field(description="List of filters to be applied")
    group_bys: List[str] = pydantic.Field(description="List of GROUP BY clauses to be used")
    order_bys: List[str] = pydantic.Field(description="List of ORDER BY clauses to be used")



#==============================#
# ---  2. DSPy Signatures ---  #
#==============================#
class schemaLinkingAgent(dspy.Signature):
    """Identify relevant tables/columns/joins for the question.
    Return STRICT structured JSON as DbSchema."""

    question: str = dspy.InputField()
    db_schema: str = dspy.InputField()
    schemaLink: DbSchema = dspy.OutputField()

class subproblemLinkingAgent(dspy.Signature):
    """You are a Subproblem Agent. Given the user request and the schema info, you decomposes the query into clause-level
    subproblems (e.g., WHERE, GROUP BY, JOIN, DISTINCT, ORDER BY, HAVING, EXCEPT, LIMIT, UNION)."""

    question: str = dspy.InputField()
    schemaLink: DbSchema = dspy.InputField()
    subProblems: Decomposition = dspy.OutputField()

class queryPlanAgent(dspy.Signature):
    """You are an SQL Planning Agent. Given the user request, schema info, and subproblems, you create a step-by-step plan
    to construct a query that will be used to solve the user's request. You produce only the procedural plan and are explicitly restricted from generating
    executable SQL at this stage."""

    question: str = dspy.InputField()
    schemaLink: DbSchema = dspy.InputField()
    subProblems: Decomposition = dspy.InputField()
    plan: QueryPlan = dspy.OutputField()

class sqlQueryAgent(dspy.Signature):
    """You are an SQL Query Agent. Given the user request, schema info, subproblems, and query plan, you generate the final executable SQL query.
    The SQL query must be valid and should accurately reflect the user's intent as outlined in the query plan.
    You must adhere to the provided rules to ensure the SQL is efficient, safe, and correct."""

    question: str = dspy.InputField()
    schemaLink: DbSchema = dspy.InputField()
    subProblems: Decomposition = dspy.InputField()
    plan: QueryPlan = dspy.InputField()
    #rules: List[str] = dspy.InputField()
    sql: str = dspy.OutputField(
        desc=(
            "The final executable SQL query." \
            #"In markdown format with SQL syntax highlighting."
        )
    )

class sqlCorrectionAgent(dspy.Signature):
    """You are a SQL Correction Agent. Given an SQL query and errors encountered during validation, you analyze the errors and produce a corrected SQL query.
    Your goal is to fix the SQL so it executes successfully and meets the user's original intent."""
    #You may use the dry_run_bigquery tool to validate your corrections before returning the final SQL."""

    sql: str = dspy.InputField()
    errors: List[str] = dspy.InputField()
    correctedSQL: str = dspy.OutputField(
        desc="The revised SQL query that fixes all validation errors."
    )

#===============================#
# --- 3. TextToSQL Pipeline --- #
#===============================#

class SQLOfThought(dspy.Module):
    def __init__(self):
        super().__init__()

        self.schemaAgent = dspy.Predict(schemaLinkingAgent)
        self.subproblemAgent = dspy.Predict(subproblemLinkingAgent)
        self.planAgent = dspy.Predict(queryPlanAgent)
        self.sqlAgent = dspy.Predict(sqlQueryAgent)
        self.correctionAgent = dspy.Predict(sqlCorrectionAgent)

    def forward(self, question, db_schema):
        
        schemaInfo = self.schemaAgent(question=question, db_schema=db_schema)

        subproblems = self.subproblemAgent(
            schemaLink=schemaInfo.schemaLink,
            question=question
        )

        queryPlan = self.planAgent(
            schemaLink=schemaInfo.schemaLink,
            question=question,
            subProblems=subproblems.subProblems
        )

        sqlQuery = self.sqlAgent(
            schemaLink=schemaInfo.schemaLink,
            question=question,
            subProblems=subproblems.subProblems,
            rules=RULES,
            plan=queryPlan.plan
        )
        
        return dspy.Prediction(sql=sqlQuery.sql)
        
sqlofthought = SQLOfThought()

In [182]:
## Load Training Dataset ##
dataset = load_dataset("gretelai/synthetic_text_to_sql")['train'][:10]

# Get the total number of rows to loop through
num_rows = len(dataset['sql']) 

examples = [
    dspy.Example(
        question=dataset['sql_prompt'][i],
        db_schema=dataset['sql_context'][i],
        gold_query=dataset['sql'][i]
    ).with_inputs("question", "db_schema")
    for i in range(num_rows)
]


In [175]:
class JudgeSQLQuality(dspy.Signature):
    """
##### SQL quality rubric
Criterion A: Adherence to Database SCHEMA
* Score = 4 Perfectly aligns with specified db_schema.
* Score = 3 Mostly aligns with specified db_schema with minor deviations.
* Score = 2 Moderate deviation from the db_schema.
* Score = 1 Significant deviations from the db_schema.
* Score = 0 Does not adhere to the db_schema.


Criterion B: Readability and Maintainability (Is the SQL code easy to understand and maintain?)
* Score = 4 The code is excellently formatted and thoroughly commented, uses meaningful aliases/variable names, ensuring high readability and ease of maintenance; organizes complex queries well.
* Score = 3 The code is well-formatted and commented, making it relatively easy to understand and maintain; uses aliases and names with some organization of complex queries.
* Score = 2 The code is somewhat readable with basic formatting and some comments, but improvements are needed; needs better use of aliases/names and organization.
* Score = 1 The code has minimal formatting and few comments, making it hard to understand; lacks meaningful names and organization.
* Score = 0 The code is unreadable, with no attempt at formatting or commenting.


Criterion C: Scalability (Does the solution scale well with larger datasets or more complex queries?)
* Score = 4 The solution is highly scalable, effortlessly handling large datasets and complex queries without performance degradation; avoids inefficient patterns like Cartesian joins.
* Score = 3 The solution scales well, maintaining performance with increased data volumes and complexity; minor areas for optimization.
* Score = 2 The solution is moderately scalable, handling larger datasets with some performance issues; misses some opportunities for using scalability practices.
* Score = 1 The solution shows poor scalability, with notable performance degradation under increased load; lacks effective scalability techniques.
* Score = 0 The solution does not scale; overlooks fundamental scalability practices, resulting in significant issues.


Criterion D: Compliance with Standards (Does the SQL query follow SQL standards and best practices?)
* Score = 4 The query strictly adheres to SQL standards and best practices, showcasing exemplary coding standards.
* Score = 3 The query closely follows SQL standards and adheres to many best practices.
* Score = 2 The query generally follows SQL standards but has room for better alignment with best practices.
* Score = 1 The query loosely follows SQL standards, with several deviations from best practices.
* Score = 0 The query does not follow SQL standards or best practices, using deprecated or non-standard syntax
    """

    question = dspy.InputField()
    db_schema = dspy.InputField()
    generated_query = dspy.InputField()
    gold_query = dspy.InputField()
    overall_score: float = dspy.OutputField(desc="Average score across all Criterion")

class LLMJudge(dspy.Module):
    def __init__(self):
        self.quality_judge = dspy.ChainOfThought(JudgeSQLQuality)

    def forward(self, gold, pred, trace=None):
        question = gold.question
        db_schema = gold.db_schema
        gold_query = gold.gold_query
        generated_query = pred.sql
        
        overall_score = self.quality_judge(
            question=question,
            db_schema=db_schema,
            gold_query=gold_query,
            generated_query=generated_query
        ).overall_score

        try:
            score = float(overall_score)
        except (ValueError, TypeError):
            score = 0.0
            
        return overall_score / 4

metric = LLMJudge()

In [183]:
evaluate = dspy.Evaluate(
    devset=examples,
    metric=metric,
    num_threads=32,
    display_table=True,
    display_progress=True,
    provide_traceback=True
)

evaluate(sqlofthought)

Average Metric: 9.45 / 10 (94.5%): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:37<00:00,  3.78s/it]

2025/10/07 23:32:51 INFO dspy.evaluate.evaluate: Average Metric: 9.453125 / 10 (94.5%)





Unnamed: 0,question,db_schema,gold_query,sql,LLMJudge
0,"What is the total volume of timber sold by each salesperson, sorte...","CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TE...","SELECT salesperson_id, name, SUM(volume) as total_volume FROM timb...","SELECT \n salesperson.name,\n SUM(timber_sales.volume) as to...",✔️ [0.969]
1,List all the unique equipment types and their corresponding total ...,"CREATE TABLE equipment_maintenance (equipment_type VARCHAR(255), m...","SELECT equipment_type, SUM(maintenance_frequency) AS total_mainten...","SELECT equipment_type, SUM(maintenance_frequency) AS total_mainten...",✔️ [1.000]
2,How many marine species are found in the Southern Ocean?,"CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR...",SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Oce...,SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean',✔️ [0.984]
3,What is the total trade value and average price for each trader an...,"CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(2...","SELECT trader_id, stock, SUM(price * quantity) as total_trade_valu...","SELECT \n trader_id,\n stock,\n SUM(price * quantity) AS ...",✔️ [1.000]
4,Find the energy efficiency upgrades with the highest cost and thei...,"CREATE TABLE upgrades (id INT, cost FLOAT, type TEXT); INSERT INTO...","SELECT type, cost FROM (SELECT type, cost, ROW_NUMBER() OVER (ORDE...","SELECT cost, type\nFROM upgrades\nWHERE cost = (SELECT MAX(cost) F...",✔️ [1.000]
5,What is the total spending on humanitarian assistance by the Europ...,CREATE SCHEMA if not exists defense; CREATE TABLE if not exists eu...,SELECT SUM(spending) FROM defense.eu_humanitarian_assistance WHERE...,SELECT SUM(spending) AS total_spending\nFROM defense.eu_humanitari...,✔️ [0.906]
6,What is the average water temperature for each fish species in Feb...,"CREATE TABLE SpeciesWaterTemp (SpeciesID int, Date date, WaterTemp...","SELECT SpeciesName, AVG(WaterTemp) as AvgTemp FROM SpeciesWaterTem...","SELECT \n SpeciesID,\n AVG(WaterTemp) AS avg_water_temperatu...",✔️ [0.750]
7,Delete a program's outcome data,"CREATE TABLE Program_Outcomes (id INT, program_id INT, outcome_typ...",DELETE FROM Program_Outcomes WHERE program_id = 1002;,DELETE FROM Program_Outcomes\nWHERE program_id = ?,✔️ [1.000]
8,Find the total fare collected from passengers on 'Green Line' buses,"CREATE TABLE bus_routes (route_name VARCHAR(50), fare FLOAT); INSE...",SELECT SUM(fare) FROM bus_routes WHERE route_name = 'Green Line';,SELECT SUM(fare) AS total_fare\nFROM bus_routes\nWHERE route_name ...,✔️ [1.000]
9,What is the average property size in inclusive housing areas?,"CREATE TABLE Inclusive_Housing (Property_ID INT, Inclusive VARCHAR...",SELECT AVG(Property_Size) FROM Inclusive_Housing WHERE Inclusive =...,```sql\nSELECT AVG(Property_Size) AS average_property_size\nFROM I...,✔️ [0.844]


EvaluationResult(score=94.53, results=<list of 10 results>)