In [7]:
import pandas as pd
import re

def _strip_string_literals(sql):
  """Replace single-quoted SQL string literals with '' so their content is never scanned."""
  return re.sub(r"'(?:[^']|'')*'", "''", sql)


def _max_subquery_depth(sql):
  """Return the deepest parenthesis depth at which a SELECT keyword occurs (0 for a flat query)."""
  depth = 0
  deepest = 0
  for token in re.finditer(r"\(|\)|\bSELECT\b", sql, flags=re.IGNORECASE):
    text = token.group()
    if text == "(":
      depth += 1
    elif text == ")":
      # Clamp at zero so unbalanced parentheses cannot produce a negative depth.
      depth = max(depth - 1, 0)
    else:
      deepest = max(deepest, depth)
  return deepest


def analyze_query(query):
  """
  Analyzes an SQL query string and identifies query types and nesting level.

  All keyword checks are case-insensitive and ignore text inside single-quoted
  string literals. Aggregates are only reported when COUNT/SUM/AVG/MAX/MIN is
  used as a function call (keyword followed by an opening parenthesis), so
  identifiers such as ACCOUNTS or SUMMARY do not count.

  Args:
      query: The SQL query string.

  Returns:
      A dictionary with the boolean keys "has_join", "has_where", "has_groupby"
      and "has_aggregate", the integer "nesting_level" (deepest parenthesis
      depth at which a SELECT appears; 0 when there is no subquery), and
      "schema_aware", which is always False here (placeholder flag).
  """
  sql = _strip_string_literals(query)

  def found(pattern):
    return bool(re.search(pattern, sql, flags=re.IGNORECASE))

  return {
      "has_join": found(r"\bJOIN\b"),
      "has_where": found(r"\bWHERE\b"),
      # \s+ tolerates newlines or repeated spaces between GROUP and BY.
      "has_groupby": found(r"\bGROUP\s+BY\b"),
      "has_aggregate": found(r"\b(?:COUNT|SUM|AVG|MAX|MIN)\s*\("),
      "nesting_level": _max_subquery_depth(sql),
      "schema_aware": False  # Set to True if schema awareness check is implemented
  }

def extract_table_names(query):
  """
  Extracts table names from an SQL query string.

  A table name is the token that directly follows a FROM or JOIN keyword
  (case-insensitive). Parentheses are never part of a name, so a derived
  table such as "FROM (SELECT ..." contributes no entry for the outer FROM,
  and a name at the end of a subquery is returned without the closing ")".
  Only the first table of a comma-separated FROM list is picked up.

  Args:
      query: The SQL query string.

  Returns:
      A list of unique table names found in the query, in order of first
      appearance.
  """
  tables = re.findall(r"\b(?:FROM|JOIN)\s+([^\s,;()]+)", query, flags=re.IGNORECASE)
  # dict.fromkeys removes duplicates while keeping first-seen order, so the
  # result is deterministic across runs (unlike a set).
  return list(dict.fromkeys(tables))

def extract_columns(query):
  """
  Extracts column names from the SELECT clause of an SQL query string.

  The select list of the first SELECT in the query is scanned up to the
  first FROM that is not enclosed in parentheses, and split on commas that
  are not enclosed in parentheses. Function calls with several arguments
  (COALESCE(a, b)), EXTRACT(... FROM ...) and scalar subqueries therefore
  stay in one piece. Single-quoted string literals are skipped while
  scanning.

  Args:
      query: The SQL query string.

  Returns:
      A list of the whitespace-stripped select-list entries (expressions and
      aliases are returned verbatim). Empty when the query has no SELECT or
      no top-level FROM after it.
  """
  select = re.search(r"\bSELECT\s+", query, flags=re.IGNORECASE)
  if select is None:
    return []

  # String literals come first in the alternation so commas, parentheses or
  # the word FROM inside them are consumed as part of the literal.
  scanner = re.compile(r"'(?:[^']|'')*'|\(|\)|,|\bFROM\b", flags=re.IGNORECASE)
  columns = []
  depth = 0
  item_start = select.end()
  for token in scanner.finditer(query, select.end()):
    text = token.group()
    if text.startswith("'"):
      continue
    if text == "(":
      depth += 1
    elif text == ")":
      depth = max(depth - 1, 0)
    elif depth == 0:
      # A top-level comma or FROM closes the current select-list entry.
      columns.append(query[item_start:token.start()].strip())
      if text != ",":
        return [column for column in columns if column]
      item_start = token.end()

  # No top-level FROM was found, so there is no delimited select list.
  return []


def check_schema_compliance(query, schema):
    """
    Checks if the SQL query references tables and columns that exist in the schema.

    Tables come from extract_table_names and columns from extract_columns.
    A bare "*" is always accepted, and "table.*" is accepted when the table
    exists. Qualified names are split on the last dot, so entries with more
    than one dot are evaluated instead of raising an error.

    Args:
        query: The SQL query string.
        schema: A dictionary mapping each table name to a collection of its
            column names.

    Returns:
        Boolean indicating if the query is schema compliant.
    """
    tables = extract_table_names(query)
    columns = extract_columns(query)

    # Check if all referenced tables exist in the schema
    if any(table not in schema for table in tables):
        return False

    # Check if all referenced columns exist in the corresponding tables in the schema
    for column in columns:
        if column == "*":
            # The wildcard refers to every column and cannot be missing.
            continue
        if "." in column:
            # rsplit keeps everything before the last dot as the qualifier, so
            # entries containing several dots do not break the unpacking.
            table, col = column.rsplit(".", 1)
            if table not in schema:
                return False
            if col != "*" and col not in schema[table]:
                return False
        elif not any(column in cols for cols in schema.values()):
            # If columns are not qualified with table names, assume they could belong to any table
            return False

    return True



In [12]:
# Put your DB schema here: a mapping of table name -> collection of that
# table's column names, e.g. {"users": ["id", "name"]}. check_schema_compliance
# tests table names with `in schema` and column names with `in schema[table]`.
# While this stays empty, any query that references a table or a non-wildcard
# column is reported as not schema compliant.
schema = {
}

In [13]:
# Load the evaluation set. The code below reads the 'sql_query' column; the
# file is also expected to carry a 'natural_language' column.
df = pd.read_csv(r"/eval_query.csv")

# Run the per-query analysis once, then expose each flag as its own column.
df["query_analysis"] = df["sql_query"].apply(analyze_query)
df["has_join"] = df["query_analysis"].apply(lambda x: x["has_join"])
df["has_where"] = df["query_analysis"].apply(lambda x: x["has_where"])
df["has_groupby"] = df["query_analysis"].apply(lambda x: x["has_groupby"])
df["has_aggregate"] = df["query_analysis"].apply(lambda x: x["has_aggregate"])

# Extract table names. explode() turns an empty list into NaN, so drop those
# rows before collecting the unique names.
df["table_names"] = df["sql_query"].apply(extract_table_names)
all_tables = df["table_names"].explode().dropna().unique()

# Extract columns (same NaN handling as for the table names)
df["columns"] = df["sql_query"].apply(extract_columns)
all_columns = df["columns"].explode().dropna().unique()

# Analyze number of column references
df["num_columns"] = df["columns"].apply(len)
num_columns_distribution = df["num_columns"].value_counts().sort_index()

# Nesting Level Analysis
df["nesting_level"] = df["query_analysis"].apply(lambda x: x["nesting_level"])
nesting_level_distribution = df["nesting_level"].value_counts().sort_index()

# Schema Awareness (requires the `schema` mapping to be filled in)
df["schema_aware"] = df["sql_query"].apply(lambda query: check_schema_compliance(query, schema))




In [None]:
# Preview the first rows of the frame, including the derived analysis columns.
df.head()

In [None]:
# Summary report. The header is printed once; the values are the mean of each
# boolean flag, i.e. the share of queries in which that construct appears.
print("Query Type Counts:")
print(df[["has_join", "has_where", "has_groupby", "has_aggregate"]].mean())  # Average presence of each query type

print("\nUnique Table Names:")
print(all_tables)

print("\nUnique Columns Across All Queries:")
print(all_columns)

print("\nDistribution of Number of Columns Referenced:")
print(num_columns_distribution)

print("\nDistribution of Nesting Levels:")
print(nesting_level_distribution)

print("\nSchema Awareness Check:")
print(df["schema_aware"].value_counts())
