# SQL Validation and correction experiments

In [None]:
%pip install requests
%pip install sqlglot


## Test SQL Validation using sqlglot library

In [None]:
import sqlglot



def validate_sql_statement(sql_statement):
    """
    Check whether ``sql_statement`` parses as Spark SQL using sqlglot.

    Prints the outcome and returns True if the statement parses, otherwise False.
    Parsing is lenient: sqlglot may accept text that is not a real statement,
    so True means "parseable", not "correct".
    """
    try:
        sqlglot.parse_one(sql_statement, dialect="spark")
    except (sqlglot.errors.ParseError, sqlglot.errors.TokenError) as e:
        # TokenError is raised by the tokenizer (e.g. an unterminated string literal)
        # and is not a ParseError subclass, so it has to be caught explicitly.
        print(f"SQL Parse Error: {e}")
        return False
    print(f"{sql_statement}: Parsed OK")
    return True

# Expected to be rejected: FORM is a typo for FROM, which leaves an unexpected trailing token.
validate_sql_statement("SELECT * FORM table_name")

# Accepted even though it is not valid SQL. sqlglot does not recognise SELCT as a keyword and
# appears to parse the whole text as an expression (roughly `SELCT * FORM AS table_name`),
# which is not a syntax error. Parse-only validation therefore cannot catch every typo.
validate_sql_statement("SELCT * FORM table_name")

## Attempt to execute SQL to test validity
See the sqlglot docs on [parser errors](https://sqlglot.com/sqlglot.html#parser-errors).

In [None]:
from sqlglot.executor import execute

# Toy in-memory tables: sqlglot's executor runs the query directly against lists of row dicts.
tables = {
    "sushi": [{"id": n, "price": float(n)} for n in (1, 2, 3)],
    "order_items": [
        {"sushi_id": sushi_id, "order_id": order_id}
        for sushi_id, order_id in ((1, 1), (1, 1), (2, 1), (3, 2))
    ],
    "orders": [{"id": n, "user_id": n} for n in (1, 2)],
}

# Total spend per user: orders -> order_items -> sushi.
execute(
    """
    SELECT
      o.user_id,
      SUM(s.price) AS price
    FROM orders o
    JOIN order_items i
      ON o.id = i.order_id
    JOIN sushi s
      ON i.sushi_id = s.id
    GROUP BY o.user_id
    """,
    tables=tables,
)

In [None]:
# Deliberate typo (FOM instead of FROM) to see how sqlglot's executor reports invalid SQL.
# The error is caught and printed so the notebook still runs top to bottom.
try:
    print(execute("SELECT * FOM orders", tables=tables))
except sqlglot.errors.SqlglotError as e:
    print(f"{type(e).__name__}: {e}")

## Class to check and correct SQL using OpenAI
Uses a simple prompt and a retry loop (at most `max_attempts` amendment requests; the example below uses 3) to check the given SQL and, if needed, correct it.

In [5]:
import sqlglot
import requests
import os

class SQLValidator:
    """
    Validate Spark SQL with sqlglot and ask an Azure OpenAI chat deployment to repair SQL that does not parse.
    """

    def __init__(self, azure_openai_endpoint, azure_openai_api_key, model="gpt-4", max_attempts=5, timeout=60):
        """
        Initialize the SQLValidator with Azure OpenAI endpoint, API key, and max attempts.

        azure_openai_endpoint: chat-completions URL that requests are POSTed to.
        azure_openai_api_key: key sent in the "api-key" request header.
        model: value sent as "model" in the request payload.
        max_attempts: maximum number of amendment requests made by process_sql.
        timeout: seconds requests waits for each Azure OpenAI call; None waits forever.
        """
        self.azure_openai_endpoint = azure_openai_endpoint
        self.azure_openai_api_key = azure_openai_api_key
        self.model = model
        self.max_attempts = max_attempts
        self.timeout = timeout

    def validate_sql(self, sql_statement):
        """
        Validate the SQL statement using sqlglot.
        Returns True if it parses as Spark SQL; otherwise prints the error and returns False.
        Parsing is lenient, so True means "parseable" rather than "correct".
        """
        try:
            sqlglot.parse_one(sql_statement, dialect="spark")
        except (sqlglot.errors.ParseError, sqlglot.errors.TokenError) as e:
            # TokenError comes from the tokenizer (e.g. an unterminated string literal) and
            # is not a ParseError subclass, so it must be caught separately.
            print(f"SQL Parse Error: {e}")
            return False
        return True

    def amend_sql(self, invalid_sql):
        """
        Send the invalid SQL to Azure OpenAI for amendment.
        Returns the amended SQL statement.
        Raises Exception on a non-200 response or a response without choices.
        """
        headers = {
            "Content-Type": "application/json",
            "api-key": self.azure_openai_api_key
        }
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": "You are an assistant that fixes SQL syntax errors. Only return the SQL and no other content"},
                {"role": "user", "content": f"Fix the following SQL: {invalid_sql}"}
            ]
        }

        # Without a timeout, requests can block forever on a stalled connection.
        response = requests.post(self.azure_openai_endpoint, json=payload, headers=headers, timeout=self.timeout)

        if response.status_code != 200:
            raise Exception(f"Azure OpenAI API Error: {response.status_code} - {response.text}")
        return self._extract_sql(response.json())

    @classmethod
    def _extract_sql(cls, response_body):
        """
        Return the SQL text from a chat-completions response body (first choice's message content).
        Raises Exception if the body contains no choices.
        """
        choices = response_body.get("choices") or []
        if not choices:
            raise Exception(f"Azure OpenAI API Error: no choices in response - {response_body}")
        # "content" can be missing or null, so fall back to an empty string.
        content = (choices[0].get("message") or {}).get("content") or ""
        return cls._strip_code_fence(content)

    @staticmethod
    def _strip_code_fence(text):
        """
        Strip whitespace and unwrap a Markdown code fence (```sql ... ```) if the model added one
        despite being asked to return only SQL.
        """
        text = text.strip()
        if len(text) >= 6 and text.startswith("```") and text.endswith("```"):
            inner = text[3:-3]
            first_line, newline, rest = inner.partition("\n")
            # The opening fence line may hold only a language tag such as "sql"; drop it.
            if newline and first_line.strip().lower() in ("", "sql", "sparksql", "spark-sql", "spark_sql"):
                inner = rest
            text = inner.strip()
        return text

    def process_sql(self, sql_statement):
        """
        Validate and amend the SQL statement until it passes validation or max attempts are reached.

        The input is validated first. Each of up to max_attempts amendments is validated as soon
        as it comes back, so the last amendment is not wasted. Returns the first valid version;
        raises Exception if none validates.
        """
        if self.validate_sql(sql_statement):
            print("SQL is valid!")
            return sql_statement

        for attempt in range(1, self.max_attempts + 1):
            print(f"Attempt {attempt}/{self.max_attempts}: SQL is invalid. Sending to Azure OpenAI for amendment...")
            sql_statement = self.amend_sql(sql_statement)
            print(f"Amended SQL: {sql_statement}")
            if self.validate_sql(sql_statement):
                print("SQL is valid!")
                return sql_statement

        raise Exception("Maximum attempts reached. SQL could not be validated.")


StatementMeta(, 502abb4d-e265-4189-af4e-aa93c0d07d25, 15, Finished, Available, Finished)

## Grab keys from the environment
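**TODO:** it is not yet clear how to set environment variables in a Fabric notebook. One option to investigate is reading the values from Azure Key Vault with `notebookutils.credentials.getSecret(...)`.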

In [None]:
# Read the Azure OpenAI settings from the environment and fail fast if either is missing or empty.
AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY = (
    os.getenv(variable_name) for variable_name in ("AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY")
)

if not all((AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY)):
    raise EnvironmentError("Please set the AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY environment variables.")


## My Keys to OpenAI
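**TODO:** remove these hard-coded values and read them from environment variables (or a secret store) instead. API keys should never be saved in a notebook.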

In [None]:
# Interactive fallback configuration. The endpoint default is a placeholder URL; the API key is
# never written into the notebook: it comes from the environment or is typed in at a hidden prompt.
import getpass

AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT") or (
    "https://MY-ENDPOINT.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2025-01-01-preview"
)
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY") or getpass.getpass("Azure OpenAI API key: ")


## Validate and correct the given SQL

In [None]:
sql_validator = SQLValidator(AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY, max_attempts=3)

# Deliberately invalid input: FORM is a typo for FROM.
sql = "SELECT * FORM table_name"
try:
    valid_sql = sql_validator.process_sql(sql)
except Exception as e:
    print(f"Error: {e}")
else:
    print(f"Final Valid SQL: {valid_sql}")

## Using Spark SQL to get the lakehouse schema
Can we query the target lakehouse schema and use it to ground the OpenAI prompt when fixing SQL errors?

In [None]:
from notebookutils import mssparkutils

# Example table path: abfss://BEAR_DEV@onelake.dfs.fabric.microsoft.com/ben_testing.Lakehouse/Tables/tgt/organizationwebsite

tables_path = "abfss://BEAR_DEV@onelake.dfs.fabric.microsoft.com/ben_testing.Lakehouse/Tables/tgt/"
# Use a distinct name: `tables` holds the in-memory row data passed to sqlglot's execute() above,
# and overwriting it with directory listings would break re-running those cells.
table_entries = mssparkutils.fs.ls(tables_path)
for entry in table_entries:
    print(entry.name)


## Query the lakehouse and generate a JSON schema document

In [None]:
import json
from notebookutils import mssparkutils

def getLakehouseSchema(lakehousePath):
    """
    List the entries under a lakehouse folder and read each table's column names.

    lakehousePath: abfss:// path of a folder whose child folders are Delta tables;
        a trailing "/" is optional.

    Returns a dict mapping each table (folder) name to its list of column names. If a table
    cannot be read as Delta, its value is an "Error reading schema: ..." string instead, so
    callers should check the value's type before using it.
    """
    # Normalise once so child paths are joined correctly whether or not the caller added a trailing slash.
    base_path = lakehousePath.rstrip("/") + "/"

    table_schemas = {}
    for entry in mssparkutils.fs.ls(lakehousePath):
        table_name = entry.name
        # Heuristic: table folders have no file extension, so names containing "." are skipped as files.
        if "." in table_name:
            continue
        try:
            df = spark.read.format("delta").load(f"{base_path}{table_name}")
            table_schemas[table_name] = [field.name for field in df.schema.fields]
        except Exception as e:
            # Keep going: one unreadable table should not hide the schemas of the others.
            table_schemas[table_name] = f"Error reading schema: {str(e)}"

    return table_schemas

# Column names for every table under the lakehouse's Party folder.
PARTY_TABLES_PATH = "abfss://BEAR_DEV@onelake.dfs.fabric.microsoft.com/ben_testing.Lakehouse/Tables/Party/"
schema = getLakehouseSchema(PARTY_TABLES_PATH)
print(schema)

### Test to see if we can use this output with SQLGLOT
This method looks like it needs some actual data (rows) to be used effectively.

In [None]:
from sqlglot.executor import execute

# sqlglot's executor needs concrete rows, so organizationcontactdetail gets one hand-made row.
# order_items and orders are carried over from the sushi example and are not used by this query.
tables = {
    "organizationcontactdetail": [
        {"PartyPermId": 1, "SourceTypePermId": 1.0},
    ],
    "order_items": [
        {"sushi_id": sushi_id, "order_id": order_id}
        for sushi_id, order_id in [(1, 1), (1, 1), (2, 1), (3, 2)]
    ],
    "orders": [{"id": i, "user_id": i} for i in range(1, 3)],
}

execute(
    "SELECT PartyPermId, SourceTypePermId FROM organizationcontactdetail",
    tables=tables,
)

## Use the grounding_json with OpenAI

This shows some errors: we first need to check that the SQL is syntactically correct, and then validate that it matches the schema. **TODO:** decide whether the schema validation belongs in sqlglot or in the OpenAI prompt.

May need to be tidied up.

Can we cache the schema?

In [None]:
import sqlglot
import requests
import os

class GroundedSQLValidator:
    """
    Validate Spark SQL against a lakehouse schema and ask Azure OpenAI to fix it.

    A statement passes only if it parses with sqlglot AND every table and column it references
    exists in the lakehouse schema. The schema is read at most once (via getLakehouseSchema) and
    cached, and it is included in the prompt so amendments use real table and column names.
    """

    # Lakehouse folder whose tables ground the prompt; previously hard-coded inside amend_sql.
    DEFAULT_LAKEHOUSE_PATH = "abfss://BEAR_DEV@onelake.dfs.fabric.microsoft.com/ben_testing.Lakehouse/Tables/Party/"

    def __init__(self, azure_openai_endpoint, azure_openai_api_key, model="gpt-4", max_attempts=5,
                 lakehouse_path=DEFAULT_LAKEHOUSE_PATH, lakehouse_schema=None, timeout=60):
        """
        Initialize the GroundedSQLValidator with Azure OpenAI endpoint, API key, and max attempts.

        azure_openai_endpoint: chat-completions URL that requests are POSTed to.
        azure_openai_api_key: key sent in the "api-key" request header.
        model: value sent as "model" in the request payload.
        max_attempts: maximum number of amendment requests made by process_sql.
        lakehouse_path: folder passed to getLakehouseSchema the first time the schema is needed.
        lakehouse_schema: optional pre-computed {table_name: [column_name, ...]} dict; when given,
            the lakehouse is never read.
        timeout: seconds requests waits for each Azure OpenAI call; None waits forever.
        """
        self.azure_openai_endpoint = azure_openai_endpoint
        self.azure_openai_api_key = azure_openai_api_key
        self.model = model
        self.max_attempts = max_attempts
        self.lakehouse_path = lakehouse_path
        self.timeout = timeout
        # Reading every Delta table is slow, so the schema is fetched lazily and cached here.
        self._lakehouse_schema = lakehouse_schema

    def get_lakehouse_schema(self):
        """Return the cached lakehouse schema, reading it with getLakehouseSchema on first use."""
        if self._lakehouse_schema is None:
            self._lakehouse_schema = getLakehouseSchema(self.lakehouse_path)
        return self._lakehouse_schema

    def validate_sql(self, sql_statement):
        """
        Validate the SQL statement using sqlglot.
        Returns True if it parses as Spark SQL; otherwise prints the error and returns False.
        """
        try:
            sqlglot.parse_one(sql_statement, dialect="spark")
        except (sqlglot.errors.ParseError, sqlglot.errors.TokenError) as e:
            # TokenError (e.g. an unterminated string literal) is not a ParseError subclass.
            print(f"SQL Parse Error: {e}")
            return False
        return True

    def find_schema_errors(self, sql_statement):
        """
        Best-effort check that the tables and columns referenced by the SQL exist in the schema.

        Expects SQL that already parses (see validate_sql). Returns a list of problem strings such
        as "Unknown column: X"; an empty list means no problems were found. CTE names and
        select-list aliases are defined by the query itself, so they are not looked up. Columns are
        not checked when a referenced table's schema could not be read.
        """
        # Schema values are column lists, or an error string when getLakehouseSchema could not read
        # the table; map the latter to None ("table exists, columns unknown"). Spark SQL matches
        # identifiers case-insensitively by default, so compare lower-cased names.
        table_columns = {
            str(table_name).lower(): (
                {str(column).lower() for column in columns} if isinstance(columns, list) else None
            )
            for table_name, columns in self.get_lakehouse_schema().items()
        }

        parsed = sqlglot.parse_one(sql_statement, dialect="spark")
        cte_names = {cte.alias.lower() for cte in parsed.find_all(sqlglot.exp.CTE)}
        output_aliases = {node.alias.lower() for node in parsed.find_all(sqlglot.exp.Alias)}

        problems = []
        known_columns = set()
        columns_checkable = True
        for table in parsed.find_all(sqlglot.exp.Table):
            table_name = table.name.lower()
            if not table_name or table_name in cte_names:
                continue
            if table_name not in table_columns:
                problems.append(f"Unknown table: {table.name}")
            elif table_columns[table_name] is None:
                columns_checkable = False
            else:
                known_columns |= table_columns[table_name]

        if columns_checkable:
            for column in parsed.find_all(sqlglot.exp.Column):
                column_name = column.name.lower()
                # Qualified stars (t.*) may surface as a column named "*"; there is nothing to check.
                if column_name in ("", "*") or column_name in output_aliases:
                    continue
                if column_name not in known_columns:
                    problems.append(f"Unknown column: {column.name}")

        # Report each problem once, in the order it was first found.
        return list(dict.fromkeys(problems))

    def _find_problems(self, sql_statement):
        """Return syntax or schema problems for the SQL; an empty list means it is valid."""
        if not self.validate_sql(sql_statement):
            return ["The SQL does not parse as Spark SQL."]
        return self.find_schema_errors(sql_statement)

    def amend_sql(self, invalid_sql, problems=None):
        """
        Send the invalid SQL, grounded with the lakehouse schema, to Azure OpenAI for amendment.

        problems: optional list of problem descriptions (e.g. from find_schema_errors) added to the prompt.
        Returns the amended SQL statement.
        Raises Exception on a non-200 response or a response without choices.
        """
        import json

        headers = {
            "Content-Type": "application/json",
            "api-key": self.azure_openai_api_key
        }
        # json.dumps sends real JSON, as the prompt promises (str(dict) would send a Python repr).
        schema_json = json.dumps(self.get_lakehouse_schema(), indent=2, default=str)
        user_content = (
            f"Using this JSON representation of the SQL schema {schema_json}, "
            f"fix the following SQL to conform to the given schema: {invalid_sql}"
        )
        if problems:
            user_content += "\nProblems found: " + "; ".join(problems)
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": "You are an assistant that fixes SQL syntax and schema errors. Only return the SQL and no other content"},
                {"role": "user", "content": user_content}
            ]
        }

        # Without a timeout, requests can block forever on a stalled connection.
        response = requests.post(self.azure_openai_endpoint, json=payload, headers=headers, timeout=self.timeout)

        if response.status_code != 200:
            raise Exception(f"Azure OpenAI API Error: {response.status_code} - {response.text}")
        return self._extract_sql(response.json())

    @classmethod
    def _extract_sql(cls, response_body):
        """
        Return the SQL text from a chat-completions response body (first choice's message content).
        Raises Exception if the body contains no choices.
        """
        choices = response_body.get("choices") or []
        if not choices:
            raise Exception(f"Azure OpenAI API Error: no choices in response - {response_body}")
        # "content" can be missing or null, so fall back to an empty string.
        content = (choices[0].get("message") or {}).get("content") or ""
        return cls._strip_code_fence(content)

    @staticmethod
    def _strip_code_fence(text):
        """
        Strip whitespace and unwrap a Markdown code fence (```sql ... ```) if the model added one
        despite being asked to return only SQL.
        """
        text = text.strip()
        if len(text) >= 6 and text.startswith("```") and text.endswith("```"):
            inner = text[3:-3]
            first_line, newline, rest = inner.partition("\n")
            # The opening fence line may hold only a language tag such as "sql"; drop it.
            if newline and first_line.strip().lower() in ("", "sql", "sparksql", "spark-sql", "spark_sql"):
                inner = rest
            text = inner.strip()
        return text

    def process_sql(self, sql_statement):
        """
        Validate (syntax, then schema) and amend the SQL until it passes or max attempts are reached.

        The input is checked first. Each of up to max_attempts amendments is re-checked as soon as
        it comes back. Returns the first version with no problems; raises Exception if none passes.
        """
        problems = self._find_problems(sql_statement)
        for attempt in range(1, self.max_attempts + 1):
            if not problems:
                break
            print(f"Attempt {attempt}/{self.max_attempts}: SQL is invalid. Sending to Azure OpenAI for amendment...")
            for problem in problems:
                print(f"  - {problem}")
            sql_statement = self.amend_sql(sql_statement, problems)
            print(f"Amended SQL: {sql_statement}")
            problems = self._find_problems(sql_statement)

        if problems:
            raise Exception("Maximum attempts reached. SQL could not be validated.")
        print("SQL is valid!")
        return sql_statement


## Try out the validator


In [None]:
sql_validator = GroundedSQLValidator(AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY, max_attempts=3)

# Parses fine, but SourcTypePermId appears to be a misspelling of the SourceTypePermId column.
sql = "SELECT PartyPermId, SourcTypePermId FROM organizationcontactdetail"
try:
    valid_sql = sql_validator.process_sql(sql)
except Exception as e:
    print(f"Error: {e}")
else:
    print(f"Final Valid SQL: {valid_sql}")