In [8]:
import re
import textwrap

# Define templates with proper indentation using textwrap.dedent
query_templates = [
    # Smallest value pattern
    (
        r"(?:find|list)?\s*smallest (.+) in (.+)",
        lambda m: textwrap.dedent(f"""
            SELECT {m.group(1)}, MIN({m.group(1)}) AS min_value
            FROM {m.group(2)}
            GROUP BY {m.group(1)}
            ORDER BY min_value ASC
            LIMIT 1;
        """).strip()
    ),
    
    # Largest value pattern
    (
        r"(?:find|list)?\s*largest (.+) in (.+)",
        lambda m: textwrap.dedent(f"""
            SELECT {m.group(1)}, MAX({m.group(1)}) AS max_value
            FROM {m.group(2)}
            GROUP BY {m.group(1)}
            ORDER BY max_value DESC
            LIMIT 1;
        """).strip()
    ),
    
    # Count entries in a table
    (
        r"(?:find|count)?\s*entries in (.+)",
        lambda m: textwrap.dedent(f"""
            SELECT COUNT(*) AS entry_count
            FROM {m.group(1)};
        """).strip()
    ),
    
    # Find all unique values
    (
        r"(?:find|list)?\s*unique (.+) in (.+)",
        lambda m: textwrap.dedent(f"""
            SELECT DISTINCT {m.group(1)}
            FROM {m.group(2)};
        """).strip()
    ),
]


# Define reusable keyword pattern
KEYWORDS_PATTERN = r"(?:find|list|determine|show|get|retrieve|give me\s*)?"

query_templates += [
    # Sum of a column
    (
        r"(?:find|calculate)?\s*sum of (.+) in (.+)",
        lambda m: textwrap.dedent(f"""
            SELECT SUM({m.group(1)}) AS total_sum
            FROM {m.group(2)};
        """).strip()
    ),
    
    # Average of a column
    (
        r"(?:find|calculate)?\s*average of (.+) in (.+)",
        lambda m: textwrap.dedent(f"""
            SELECT AVG({m.group(1)}) AS average_value
            FROM {m.group(2)};
        """).strip()
    ),
    
    # Minimum value in a column
    (
        r"(?:find|list)?\s*minimum (.+) in (.+)",
        lambda m: textwrap.dedent(f"""
            SELECT MIN({m.group(1)}) AS min_value
            FROM {m.group(2)};
        """).strip()
    ),
    
    # Maximum value in a column
    (
        r"(?:find|list)?\s*maximum (.+) in (.+)",
        lambda m: textwrap.dedent(f"""
            SELECT MAX({m.group(1)}) AS max_value
            FROM {m.group(2)};
        """).strip()
    ),
    
    # Count distinct values in a column
    (
        r"(?:find|count)?\s*distinct (.+) in (.+)",
        lambda m: textwrap.dedent(f"""
            SELECT COUNT(DISTINCT {m.group(1)}) AS distinct_count
            FROM {m.group(2)};
        """).strip()
    ),
    
    # Find all rows where a column equals a value
    (
        r"(?:find|list)?\s*rows where (.+) equals (.+) in (.+)",
        lambda m: textwrap.dedent(f"""
            SELECT *
            FROM {m.group(3)}
            WHERE {m.group(1)} = {m.group(2)};
        """).strip()
    ),
    
    # Find all rows where a column is greater than a value
    (
        r"(?:find|list)?\s*rows where (.+) greater than (.+) in (.+)",
        lambda m: textwrap.dedent(f"""
            SELECT *
            FROM {m.group(3)}
            WHERE {m.group(1)} > {m.group(2)};
        """).strip()
    ),
    
    # Find all rows where a column is less than a value
    (
        r"(?:find|list)?\s*rows where (.+) less than (.+) in (.+)",
        lambda m: textwrap.dedent(f"""
            SELECT *
            FROM {m.group(3)}
            WHERE {m.group(1)} < {m.group(2)};
        """).strip()
    ),
    
    # Count rows with a specific condition
    (
        r"(?:find|count)?\s*rows where (.+) equals (.+) in (.+)",
        lambda m: textwrap.dedent(f"""
            SELECT COUNT(*) AS row_count
            FROM {m.group(3)}
            WHERE {m.group(1)} = {m.group(2)};
        """).strip()
    ),
    
    # Find the top N rows by a column
    (
        rf"{KEYWORDS_PATTERN}\s*top (\d+) rows by (.+) in (.+)",
        lambda m: textwrap.dedent(f"""
            SELECT *
            FROM {m.group(3)}
            ORDER BY {m.group(2)} DESC
            LIMIT {m.group(1)};
        """).strip()
    ),
]




# Function to match and generate queries
def generate_query(user_query, templates):
    for pattern, query_func in templates:
        match = re.match(pattern, user_query, re.IGNORECASE)
        if match:
            return query_func(match)
    return None

In [9]:
# Test queries
test_queries = [
    "find smallest grade in enrollments",        # Smallest value
    "list largest credithours in courses",       # Largest value
    "count entries in students",                 # Count entries
    "list unique major in students",             # Unique values
    "calculate sum of salary in employees",      # Sum of a column
    "calculate average age in students",         # Average of a column
    "list minimum age in employees",             # Minimum value
    "list maximum salary in employees",          # Maximum value
    "count distinct department in employees",    # Count distinct values
    "list rows where grade equals A in enrollments", # Rows matching condition
    "list rows where age greater than 21 in students", # Rows with greater condition
    "list rows where salary less than 50000 in employees", # Rows with less condition
    "count rows where major equals CS in students",  # Count rows with condition
    "find top 5 rows by salary in employees",        # Top N rows
    "list top 34 rows by credithours in courses",      # Top N rows
    
    "retrieve top 34 rows by credithours in courses",
    "give me top 34 rows by credithours in courses"
]

In [10]:
# Process test queries IF JOIN IS NOT NEEDED
for user_query in test_queries:
    sql_query = generate_query(user_query, query_templates)
    if sql_query:
        print(f"User Query: {user_query}")
        print(f"Generated SQL Query:\n{sql_query}\n")
    else:
        print(f"No match found for query: {user_query}")


User Query: find smallest grade in enrollments
Generated SQL Query:
SELECT grade, MIN(grade) AS min_value
FROM enrollments
GROUP BY grade
ORDER BY min_value ASC
LIMIT 1;

User Query: list largest credithours in courses
Generated SQL Query:
SELECT credithours, MAX(credithours) AS max_value
FROM courses
GROUP BY credithours
ORDER BY max_value DESC
LIMIT 1;

User Query: count entries in students
Generated SQL Query:
SELECT COUNT(*) AS entry_count
FROM students;

User Query: list unique major in students
Generated SQL Query:
SELECT DISTINCT major
FROM students;

User Query: calculate sum of salary in employees
Generated SQL Query:
SELECT SUM(salary) AS total_sum
FROM employees;

No match found for query: calculate average age in students
User Query: list minimum age in employees
Generated SQL Query:
SELECT MIN(age) AS min_value
FROM employees;

User Query: list maximum salary in employees
Generated SQL Query:
SELECT MAX(salary) AS max_value
FROM employees;

User Query: count distinct depar