In [7]:
import pandas as pd
import re

def analyze_query(query):
  """
  Analyzes an SQL query string and identifies query types and nesting level.

  Args:
      query: The SQL query string.

  Returns:
      A dictionary with keys indicating presence of query types (join, where, etc.), nesting level, and a flag for schema awareness check (if implemented).
  """
  analysis = {
      "has_join": bool(re.search(r"\bJOIN\b", query, flags=re.IGNORECASE)),
      "has_where": bool(re.search(r"\bWHERE\b", query, flags=re.IGNORECASE)),
      "has_groupby": bool(re.search(r"\bGROUP BY\b", query, flags=re.IGNORECASE)),
      "has_aggregate": any(word in query for word in ["COUNT", "SUM", "AVG", "MAX", "MIN"]),
      "nesting_level": 0,
      "schema_aware": False  # Set to True if schema awareness check is implemented
  }



  # Implement logic to calculate nesting level based on parentheses or keywords (e.g., WHERE, JOIN)
  # This is a simplified example, adapt it for your specific SQL dialect syntax
  subquery_keywords = ["SELECT", "WHERE"]
  for keyword in subquery_keywords:
          open_subqueries = query.count(f"{keyword}")
          # Divide by 2 to account for double counting due to opening and closing keywords in subqueries
          analysis["nesting_level"] += max(int(open_subqueries / 2), 0)  # Ensure integer division
  return analysis

def extract_table_names(query):
  """
  Extracts table names from an SQL query string.

  Args:
      query: The SQL query string.

  Returns:
      A list of unique table names found in the query.
  """
  tables = re.findall(r"\bFROM\s+([^\s,;]+)", query, flags=re.IGNORECASE)
  return list(set(tables))

def extract_columns(query):
  """
  Extracts column names from the SELECT clause of an SQL query string.

  Args:
      query: The SQL query string.

  Returns:
      A list of column names from the query.
  """
  columns = re.findall(r"\bSELECT\s+(.+?)\bFROM", query, flags=re.DOTALL | re.IGNORECASE)
  if columns:
    return [col.strip() for col in columns[0].split(",")]
  else:
    return []


def check_schema_compliance(query, schema):
    """
    Checks if the SQL query references tables and columns that exist in the schema.

    Args:
        query: The SQL query string.
        schema: A dictionary representing the database schema.

    Returns:
        Boolean indicating if the query is schema compliant.
    """
    tables = extract_table_names(query)
    columns = extract_columns(query)

    # Check if all referenced tables exist in the schema
    for table in tables:
        if table not in schema:
            return False

    # Check if all referenced columns exist in the corresponding tables in the schema
    for column in columns:
        if "." in column:
            table, col = column.split(".")
            if table not in schema or col not in schema[table]:
                return False
        else:
            # If columns are not qualified with table names, assume they could belong to any table
            if not any(column in cols for cols in schema.values()):
                return False

    return True



In [12]:
schema = {
"Contract" : [
"Expiration Date","Supplier",
"ID","TCV","Term Type",
"Reporting Currency",
"Status","Title","Document Type",
"Effective Date",
"Functions","Services",
"Regions","Countries",
"Time Zone","Currencies",
"Agreement Type",
"Name","Source Name/Title"
],


"Contract Draft Request" : [
"ID","Title",
"Suppliers","ESignature Status",
"Source Name/Title","Total Deviations",
"Effective Date","TCV","Paper Type",
"Status","Regions","Countries",
"Functions","Services","Templates",
"Counterparty Type","Agreement Type",
"Expiration Date","Multilingual",
"No Touch Contract"
],

"CO/CDR" : [
"Created On","Created By",
"Counterparty","Reporting Date"
]
}

In [13]:
df = pd.read_csv(r"/Users/shrini/work/sirion/data/eval_query.csv")
# Assuming your dataframe is called 'df' with columns 'natural_language' and 'sql_query'
df["query_analysis"] = df["sql_query"].apply(analyze_query)

# Assuming your dataframe is called 'df' with columns 'natural_language' and 'sql_query'
df["query_analysis"] = df["sql_query"].apply(analyze_query)
df["has_join"] = df["query_analysis"].apply(lambda x: x["has_join"])
df["has_where"] = df["query_analysis"].apply(lambda x: x["has_where"])
df["has_groupby"] = df["query_analysis"].apply(lambda x: x["has_groupby"])
df["has_aggregate"] = df["query_analysis"].apply(lambda x: x["has_aggregate"])

# Extract table names
df["table_names"] = df["sql_query"].apply(extract_table_names)
all_tables = df["table_names"].explode().unique()

# Extract columns
df["columns"] = df["sql_query"].apply(extract_columns)
all_columns = df["columns"].explode().unique()




# Analyze number of column references
df["num_columns"] = df["columns"].apply(len)
num_columns_distribution = df["num_columns"].value_counts().sort_index()

# Nesting Level Analysis
df["nesting_level"] = df["query_analysis"].apply(lambda x: x["nesting_level"])#["nesting_level"]
nesting_level_distribution = df["nesting_level"].value_counts().sort_index()



# Schema Awareness
df["schema_aware"] = df["sql_query"].apply(lambda query: check_schema_compliance(query, schema))


# Schema Awareness (**Requires access to schema definition**)
#  - Replace this comment with your implementation to compare referenced tables/columns with actual schema
#  df["schema_aware"] = df.apply(lambda row: check_schema_compliance(row["sql_query"], schema), axis=1)




In [14]:
df.head()

Unnamed: 0,natural_language,Entity,sql_query,query_analysis,has_join,has_where,has_groupby,has_aggregate,table_names,columns,num_columns,nesting_level,schema_aware
0,List all contracts that are in draft stage) an...,Contract Draft Request,"SELECT agreement_type, title\nFROM CDR;","{'has_join': False, 'has_where': False, 'has_g...",False,False,False,False,[CDR],"[agreement_type, title]",2,0,False
1,What are the document types of contracts creat...,Contract,SELECT DISTINCT document_type\nFROM contracts\...,"{'has_join': False, 'has_where': True, 'has_gr...",False,True,False,False,[contracts],[DISTINCT document_type],1,0,False
2,What is the total number of contracts by Count...,Contract,"SELECT COUNT(*) AS total_contracts, supplier\n...","{'has_join': False, 'has_where': False, 'has_g...",False,False,False,True,[contract],"[COUNT(*) AS total_contracts, supplier]",2,0,False
3,What is our spending with IT vendors this year...,Contract,"SELECT\n SUM(tcv) AS spending_this_year,\n ...","{'has_join': False, 'has_where': True, 'has_gr...",False,True,False,True,[contract],"[SUM(tcv) AS spending_this_year, (SELECT SUM(t...",2,2,False
4,"""Show me the payment terms of the contract ID ...",Contract,SELECT *\nFROM contract\nWHERE id = 'CO200456'...,"{'has_join': False, 'has_where': True, 'has_gr...",False,True,False,False,[contract],[*],1,0,False


In [10]:
print("Query Type Counts:")
#print(df["query_analysis"].mean())  # Average presence of each query type

import pandas as pd
import re


print("Query Type Counts:")
print(df[["has_join", "has_where", "has_groupby", "has_aggregate"]].mean())  # Average presence of each query type

print("\nUnique Table Names:")
print(all_tables)

print("\nUnique Columns Across All Queries:")
print(all_columns)

print("\nDistribution of Number of Columns Referenced:")
print(num_columns_distribution)

print("\nDistribution of Nesting Levels:")
print(nesting_level_distribution)

print("\nSchema Awareness Check:")
print(df["schema_aware"].value_counts())


Query Type Counts:
Query Type Counts:
has_join         0.000000
has_where        0.862745
has_groupby      0.049020
has_aggregate    0.098039
dtype: float64

Unique Table Names:
['CDR' 'contracts' 'contract' 'contract_draft_request' nan]

Unique Columns Across All Queries:
['agreement_type' 'title' 'DISTINCT document_type'
 'COUNT(*) AS total_contracts' 'supplier' 'SUM(tcv) AS spending_this_year'
 '(SELECT SUM(tcv)' '*' 'payment_terms' 'COUNT(*) AS contract_count'
 'COUNT(*) AS cdr_count' 'SUM(tcv) AS total_contract_value' 'regions'
 'SUM(tcv) AS total_tcv' 'term_type' 'document_type' 'COUNT(*) AS count'
 'functions' 'AVG(tcv) AS average_value_of_contracts' 'services' 'tcv'
 'keywords' 'SUM(tcv) AS revenue'
 'DISTINCT regions AS state_governing_laws' 'expiration_date'
 "(COUNT(CASE WHEN status = 'renewed' THEN 1 END) * 100.0 / COUNT(*)) AS percentage_renewed"
 "(COUNT(CASE WHEN status = 'terminated' THEN 1 END) * 100.0 / COUNT(*)) AS percentage_terminated"
 'COUNT(*) AS contracts_with_