In [None]:
import re

class ChatDB:
    def __init__(self):
        self.query_patterns = {
            "total <A> by <B>": self._total_by_query,
            "average <A> by <B>": self._average_by_query,
            "count <A> by <B>": self._count_by_query,
            "find <A> where <B> is <C>": self._find_where_query
        }
        self.phrase_replacements = {
            "broken down by": "by",
        }

    def natural_language_to_sql(self, question, table_name, column_names):
        """
        Convert a natural language question into an SQL query for the given table and validate column names.
        """
        cleaned_question = self._preprocess_question(question)
        print(f"[DEBUG] Cleaned Question: {cleaned_question}")  # Debugging line
        for pattern, function in self.query_patterns.items():
            if self._match_pattern(cleaned_question, pattern):
                try:
                    print(f"[DEBUG] Matched Pattern: {pattern}")  # Debugging line
                    return function(cleaned_question, table_name, column_names)
                except Exception as e:
                    return f"Error occurred while processing query with pattern '{pattern}': {str(e)}"
        return f"Query pattern not recognized for question: '{question}'"

    def _find_closest_column_name(self, term, column_names):
        """
        Match the given term to the closest column name in the provided list.
        """
        from difflib import get_close_matches
        matches = get_close_matches(term.replace(' ', '_'), column_names, n=1, cutoff=0.5)
        return matches[0] if matches else term.replace(' ', '_')

    def _preprocess_question(self, question):
        question = question.lower()
        for phrase, replacement in self.phrase_replacements.items():
            question = question.replace(phrase, replacement)
        question = re.sub(r'\s+', ' ', question).strip()
        return question

    def _match_pattern(self, question, pattern):
        # Replace placeholders like <A>, <B>, <C> with regex groups
        regex = re.sub(r"<\w+>", r"([\\w\\s]+)", pattern)
        print(f"[DEBUG] Regex Pattern: {regex}")  # Debugging line
        return re.search(regex, question) is not None

    def _total_by_query(self, question, table_name, column_names):
        match = re.search(r"total ([\w\s]+) by ([\w\s]+)", question)
        if match:
            a = self._find_closest_column_name(match.group(1).strip(), column_names)
            b = self._find_closest_column_name(match.group(2).strip(), column_names)
            return f"SELECT `{b}`, SUM(`{a}`) FROM `{table_name}` GROUP BY `{b}`;"
        return "Invalid query format."

    def _average_by_query(self, question, table_name, column_names):
        match = re.search(r"average ([\w\s]+) by ([\w\s]+)", question)
        if match:
            a = self._find_closest_column_name(match.group(1).strip(), column_names)
            b = self._find_closest_column_name(match.group(2).strip(), column_names)
            return f"SELECT `{b}`, AVG(`{a}`) FROM `{table_name}` GROUP BY `{b}`;"
        return "Invalid query format."

    def _count_by_query(self, question, table_name, column_names):
        match = re.search(r"count ([\w\s]+) by ([\w\s]+)", question)
        if match:
            a = self._find_closest_column_name(match.group(1).strip(), column_names)
            b = self._find_closest_column_name(match.group(2).strip(), column_names)
            return f"SELECT `{b}`, COUNT(`{a}`) FROM `{table_name}` GROUP BY `{b}`;"
        return "Invalid query format."

    def _find_where_query(self, question, table_name, column_names):
        match = re.search(r"find ([\w\s]+) where ([\w\s]+) is ([\w\s]+)", question)
        if match:
            a = self._find_closest_column_name(match.group(1).strip(), column_names)
            b = self._find_closest_column_name(match.group(2).strip(), column_names)
            c = match.group(3).strip()  # No need to validate values
            return f"SELECT `{a}` FROM `{table_name}` WHERE `{b}` = '{c}';"
        return "Invalid query format."



# Example Usage
if __name__ == "__main__":
    chat_db = ChatDB()
    queries = [
        "Get total sales amount broken down by product category",
        "What is the average revenue by region",
        "Count customers by age group",
        "Find products where category is electronics"
    ]
    for query in queries:
        print(f"Query: {query}")
        result = chat_db.natural_language_to_sql(query)
        print(f"SQL: {result}")
        print()


Query: Get total sales amount broken down by product category
[DEBUG] Cleaned Question: get total sales amount by product category
[DEBUG] Regex Pattern: total ([\w\s]+) by ([\w\s]+)
[DEBUG] Matched Pattern: total <A> by <B>
SQL: SELECT `product_category`, SUM(`sales_amount`) FROM `sales` GROUP BY `product_category`;

Query: What is the average revenue by region
[DEBUG] Cleaned Question: what is the average revenue by region
[DEBUG] Regex Pattern: total ([\w\s]+) by ([\w\s]+)
[DEBUG] Regex Pattern: average ([\w\s]+) by ([\w\s]+)
[DEBUG] Matched Pattern: average <A> by <B>
SQL: SELECT `region`, AVG(`revenue`) FROM `sales` GROUP BY `region`;

Query: Count customers by age group
[DEBUG] Cleaned Question: count customers by age group
[DEBUG] Regex Pattern: total ([\w\s]+) by ([\w\s]+)
[DEBUG] Regex Pattern: average ([\w\s]+) by ([\w\s]+)
[DEBUG] Regex Pattern: count ([\w\s]+) by ([\w\s]+)
[DEBUG] Matched Pattern: count <A> by <B>
SQL: SELECT `age_group`, COUNT(`customers`) FROM `sales` GRO

In [31]:
# Example usage
chatdb = ChatDB()
question = "find total sales amount broken down by product category"
sql_query = chatdb.natural_language_to_sql(question)
print(sql_query)

# Additional examples
question2 = "find average sales amount by region"
sql_query2 = chatdb.natural_language_to_sql(question2)
print(sql_query2)

question3 = "count orders by customer"
sql_query3 = chatdb.natural_language_to_sql(question3)
print(sql_query3)

question4 = "find product where category is electronics"
sql_query4 = chatdb.natural_language_to_sql(question4)
print(sql_query4)


[DEBUG] Cleaned Question: find total sale amount broken product category
[DEBUG] Regex Pattern: total\ <A>\ by\ <B>
[DEBUG] Regex Pattern: average\ <A>\ by\ <B>
[DEBUG] Regex Pattern: count\ <A>\ by\ <B>
[DEBUG] Regex Pattern: find\ <A>\ where\ <B>\ is\ <C>
Query pattern not recognized for question: 'find total sales amount broken down by product category'
[DEBUG] Cleaned Question: find average sale amount region
[DEBUG] Regex Pattern: total\ <A>\ by\ <B>
[DEBUG] Regex Pattern: average\ <A>\ by\ <B>
[DEBUG] Regex Pattern: count\ <A>\ by\ <B>
[DEBUG] Regex Pattern: find\ <A>\ where\ <B>\ is\ <C>
Query pattern not recognized for question: 'find average sales amount by region'
[DEBUG] Cleaned Question: count order customer
[DEBUG] Regex Pattern: total\ <A>\ by\ <B>
[DEBUG] Regex Pattern: average\ <A>\ by\ <B>
[DEBUG] Regex Pattern: count\ <A>\ by\ <B>
[DEBUG] Regex Pattern: find\ <A>\ where\ <B>\ is\ <C>
Query pattern not recognized for question: 'count orders by customer'
[DEBUG] Cleaned

In [47]:
import re

def natural_language_to_sql(query):
    # Normalize the query (convert to lowercase and trim extra spaces)
    query = query.lower().strip()
    
    # Handle natural language phrases like "broken down by"
    query = re.sub(r"broken down by", "by", query)
    
    # Define the regex pattern for matching
    pattern = r"total (\w+(?: \w+)*) by (\w+(?: \w+)*)"
    
    # Check for a match
    match = re.search(pattern, query)
    if match:
        # Extract captured groups
        field = match.group(1).strip()
        group_by = match.group(2).strip()
        
        # Normalize field names (replace spaces with underscores)
        field = field.replace(" ", "_")
        group_by = group_by.replace(" ", "_")
        
        # Generate SQL query
        sql_query = f"SELECT {group_by}, SUM({field}) AS total_{field} FROM table_name GROUP BY {group_by};"
        return sql_query
    else:
        return "Query does not match the pattern."

# Example Test
query = "Get total sales amount broken down by product category"
result = natural_language_to_sql(query)
print(result)



SELECT product_category, SUM(sales_amount) AS total_sales_amount FROM table_name GROUP BY product_category;
