In [1]:
import re 
import token

## old version: 

In [None]:
# SQL output format
## Number of attributes : CDC data = 12 ; FRED = 2 ; WEGOVY = 2  
'''
1) SELECT ... WHERE ... 

2) SELECT ... GROUP BY ... 
 2.1) SELECT ... GROUP BY ... HAVING ...

3) SELECT ... WHERE ... ORDER BY ... LIMIT ... 

4) Aggregation functions: 
 4.1) SELECT SUM()... FROM ...
 4.2) SELECT AVG()... FROM ...  
 4.3) SELECT MIN()... FROM ...
 4.4) SELECT MAX()... FROM ...
 4.5) SELECT COUNT()... FROM ...

5) Patterns of WHERE clauses -- 
 5.1) Basic comparison: 
    SELECT * FROM table_name WHERE column_name = 'value';
 
 5.2) Range Queries: 
    SELECT * FROM table_name WHERE column_name BETWEEN value1 AND value2;
    SELECT * FROM table_name WHERE column_name >= value;

 5.3) Pattern matching: 
    SELECT * FROM table_name WHERE column_name LIKE '%pattern%';
    SELECT * FROM table_name WHERE column_name LIKE '_A_';

 5.4) Checking for NULL: 
    SELECT * FROM table_name WHERE column_name IS NULL;
    SELECT * FROM table_name WHERE column_name IS NOT NULL;

 5.5) Using IN: 
    SELECT * FROM table_name WHERE column_name IN (value1, value2, value3);

 5.6) Logical condition:   
    SELECT * FROM table_name 
    WHERE condition1 AND condition2;

    SELECT * FROM table_name 
    WHERE condition1 OR condition2;

    SELECT * FROM table_name 
    WHERE NOT condition; 

 5.7) Using EXISTS
    SELECT * FROM table_name WHERE EXISTS (subquery);
 
 5.8) Subqueries 
    SELECT * FROM table_name WHERE column_name = (SELECT value FROM another_table WHERE condition);

 5.9) Conditional Aggregations 
    SELECT field, COUNT(*) FROM employees 
    GROUP BY field 
    HAVING COUNT(*) > 5;

 5.10) Case-Insentive Filtering 
    SELECT * FROM table_name WHERE UPPER(column_name) = 'VALUE';

'''


'\n1) SELECT ... WHERE ... \n\n2) SELECT ... GROUP BY ... \n2.1) SELECT ... GROUP BY ... HAVING ...\n\n3) SELECT ... WHERE ... ORDER BY ... LIMIT ... \n\n4.1) SELECT SUM()... FROM ...\n4.2) SELECT AVG()... FROM ...  \n4.3) SELECT MIN()... FROM ...\n4.4) SELECT MAX()... FROM ...\n4.5) SELECT COUNT()... FROM ...\n\n'

main: 

In [125]:
def parse(sample_nlq,target=None, tokens=2):  
    p1 = sample_nlq.split(" ")  
    if target is not None: 
        target_tokens = target.split(" ")
        len_target_tokens = len(target_tokens)   
        #target_idx_start = p1.index(target_tokens[0])
        target_idx_end = p1.index(target_tokens[-1])   
        return p1[target_idx_end+1:target_idx_end+tokens]         
    return p1   

class nlq_format_checker: 
    def __init__(self):
        self.all_att_list= {'FRED': ['date','income'], 
                            'WEGOVY': ['quarter','sales'], 
                            'STOCK':['DATE','NOV','LLY'],
                            'CDC': ['SUBTOPIC', 'SUBTOPIC_ID', 'CLASSIFICATION', 'CLASSIFICATION_ID', 'GROUP_NAME', 'GROUP_ID', 
                                    'SUBGROUP', 'SUBGROUP_ID', 'ESTIMATE_TYPE', 'ESTIMATE_TYPE_ID', 'TIME_PERIOD', 
                                    'TIME_PERIOD_ID', 'ESTIMATE', 'STANDARD_ERROR']} 
        self.all_table_list = ['FRED', 'WEGOVY', 'CDC', 'STOCK']  

        self.required_att_list = []         
        self.required_table_list = []
        self.context_group_set = []  

        self.subquery = ""
        self.where = ""
        self.groupby = "" 

        self.start_nlq = ["find", "group", "show", "retrieve", "calculate", "identify", "count", "list", "display"]
        self.att_nlq = [] 

    def parse(self, sample_nlq): 
        # Check whether the nlq inclues at least one valid table name and one attribute name 
        self.table_name_checker(sample_nlq) 
        self.att_name_checker(sample_nlq)

        # preset 
        context_group = {}
        # 1) SELECT ... FROM ... 
        pattern = r'\b(?:find|display|show|list|identify|retrieve|tell)\b'
        match = re.search(pattern, sample_nlq, re.IGNORECASE)
        if match:               
            #result = parse(sample_nlq, f"{match.group()}") 
            #print(result)
            #filter_select = [re.split(r'\b(?:in the)\b', result[-1] ) ] 
            #print(filter_select)
            idx_match  = sample_nlq.index(f"{match.group()}") 
            idx_from_start = re.search(r'\b(in the|from the)\b', sample_nlq).start() if re.search(r'\b(in the|from the)\b', sample_nlq) else -1
            idx_from_end = re.search(r'\b(in the|from the)\b', sample_nlq).end() if re.search(r'\b(in the|from the)\b', sample_nlq) else -1

            if idx_from_start ==-1: 
                return ValueError("Your query must include 'from <table_name>' or 'in the <table_name>'.")

            if idx_match < idx_from_start:
                 context_group["SELECT"] = parse(sample_nlq[:idx_from_start], f"{match.group()}",tokens=5) 
            else: 
                context_group["SELECT"] = parse(sample_nlq[idx_from_end+1:], f"{match.group()}",tokens=5)   
            context_group["FROM"] = parse(sample_nlq[idx_from_end+1:])[0]    
        else : 
            return ValueError("Your query must include at least one word: 'display','show','list','identify','retrieve','tell'.")

        # 2) WHERE clause
        pattern2 = r'\b(?:where)\b'
        match2 = re.search(pattern2, sample_nlq, re.IGNORECASE)
        if match2:  
            context_group["WHERE"] = parse(sample_nlq,f"{match.group()}")
        
        # 3) GROUP BY 
        pattern2 = r'\b(?:group|grouped)\b'
        match2 = re.search(pattern2, sample_nlq, re.IGNORECASE)
        if match2:  
            context_group["GROUP BY"] = parse(sample_nlq,"by") 

        # 4) order by
        pattern3 = r'\b(?:order|ordered|sorted|ascending|descending)\b'
        match3 = re.search(pattern3, sample_nlq, re.IGNORECASE)
        if match3:  
            context_group["ORDER BY"] = parse(sample_nlq,"by")
        pattern3 = r'\b(?:ascending|descending)\b'
        match3 = re.search(pattern3, sample_nlq, re.IGNORECASE) 
        if match3: 
            context_group["ORDER BY"].append(match3.group()[:4])

        # 5) Limit 
        pattern4 = r'\b(?:limit|only|top|highest|lowest)\b' 
        match4 = re.search(pattern4, sample_nlq, re.IGNORECASE) 
        if match4:
            if match4.group() in ['highest','lowest']:  
                context_group["LIMIT"] = '1' 
            elif match4: 
                context_group["LIMIT"] = parse(sample_nlq[match4.end()+1:])[0]  
        
        return context_group  


    def table_name_checker(self, nlq): 
        for i in self.all_table_list: 
            if i in nlq: 
                self.required_table_list.append(i) 
        if len(self.required_table_list) == 0: 
            raise ValueError("You missed a valid table name") 
            
    def att_name_checker(self, nlq): 
        for i in self.all_att_list: 
            if i in nlq: 
                self.required_att_list.append(i)
        if len(self.required_att_list) == 0: 
            raise ValueError("You missed a valid attribute name") 

    def add_template(self, sample_nlq):
        context_group = self.parse(sample_nlq)
        
        # Append the new context group to the context_group_set
        self.context_group_set.append(context_group)
    
    def convert_to_SQL(self): 
        # return SQL based on the self.context_group_set
        pass 

In [89]:
checker = nlq_format_checker()
checker.parse("Group all entries in the FRED dataset by state and display each state's population.") 

{'SELECT': ['each', "state's", 'population.'],
 'FROM': ['FRED'],
 'GROUP BY': ['state']}

In [None]:
checker.parse("Retrieve all records in the CDC data for patients aged between 20 and 30.") 

{'SELECT': ['all', 'records', ''], 'FROM': ['CDC']}

## New version:

In [None]:
import re 
import pandas as pd 

def puctuation_remover(sample_nlq): 
    nlq = sample_nlq.replace(",","")
    nlq = nlq.replace("!","")
    nlq = nlq.replace("'","")
    return  nlq 

class nlq_parser: 
    def __init__(self,sample_nlq=None):
        self.nlq = sample_nlq 
        self.full_table_list = ['FRED', 'CDC', 'STOCK']  
        self.full_att_list = {'FRED': ['date','income'], 
                            'STOCK':['DATE','NOV','LLY'],
                            'CDC': ['SUBTOPIC', 'SUBTOPIC_ID', 'CLASSIFICATION', 'CLASSIFICATION_ID', 'GROUP_NAME', 'GROUP_ID', 
                                    'SUBGROUP', 'SUBGROUP_ID', 'ESTIMATE_TYPE', 'ESTIMATE_TYPE_ID', 'TIME_PERIOD', 
                                    'TIME_PERIOD_ID', 'ESTIMATE', 'STANDARD_ERROR']} 
        
        self.nlq_start_token = ["find", "group", "show", "retrieve", "calculate", "identify", "count", "list", "display", "tell"]
        self.nlq_agg_token = ['sum', 'total', 'average', 'mean', 'max', 'min']
        self.nlq_tokens = None
        

    def check_atts(self, sample_nlq): 
        # Input: sample_nlq 
        # Output: Dictionary = Key: Attribute tokens, Value: index in NLQ
        sample_nlq = puctuation_remover(sample_nlq) 
        self.nlq_tokens = sample_nlq.split(" ")  
        atts_ = {} 
        for _, i in enumerate(self.full_att_list.values()): 
            #print(i) 
            for j in i : 
                if j in self.nlq_tokens:  
                    #print(self.nlq_tokens.index(j))
                    atts_[j] = self.nlq_tokens.index(j)   
        
        if len(atts_) == 0 : 
            return ValueError(f"You must include valid attribute names: {self.full_att_list}")  
        #print(atts_)
        return atts_ 

    def check_from(self, sample_nlq): 
        from_ = "FROM " 
        for i in self.full_table_list :
            if i in sample_nlq and i not in from_:
                from_ = from_ + i + ","
        
        from_ = from_[:-1]
        #print(from_)
        return from_ 

    def check_select(self, sample_nlq):
        cnt = 0
        pattern = r'\b(?:find|display|show|list|identify|retrieve|tell)\b'
        atts_position = self.check_atts(sample_nlq) #-- You can see the position index of attribute name in NLQ  
        match = re.search(pattern, sample_nlq, re.IGNORECASE)
        if match:    
            cnt += 1 

        if cnt ==0 : 
            return ValueError("Your query must include at least one word: 'display','show','list','identify','retrieve','tell'.")
        else:
            target_atts = self.adjacent_token(match.group(), atts_position) 
            return f"SELECT {target_atts}"
        
    def check_groupby(self, sample_nlq): 
        cnt = 0
        pattern = r'\b(?:group|group by|each)\b'
        atts_position = self.check_atts(sample_nlq) #-- You can see the position index of attribute name in NLQ   
        match = re.search(pattern, sample_nlq, re.IGNORECASE)
        if match:        
            target_att_by = self.adjacent_token(match.group(), atts_position) 
            return f"GROUP BY {target_att_by}"
        else:
            return     
        
    def check_aggregation(self, sample_nlq): 
        aggregation_functions = [] 
        for i in self.nlq_agg_token: 
            if i in sample_nlq : 
                aggregation_functions.append(i) 

    def adjacent_token(self, token:str, atts_position:dict, verbose=False):
        target_idx = self.nlq_tokens.index(token)
        min_distance = 99999
        adjacent_token = "" 
        for k, v in atts_position.items() :
            distance = abs(v-target_idx)
            if distance < min_distance:
                min_distance = distance 
                adjacent_token = k 
        if verbose== False:
            return adjacent_token
        else: 
            print(atts_position)
            print(f"Target({token}) index:",target_idx)
            return adjacent_token 


In [103]:
checker = nlq_parser()
nlq = "Find the total ESTIMATE for each TIME_PERIOD_ID in CDC." 
print(checker.check_atts(nlq) )

print(checker.check_select(nlq)) 
print(checker.check_from(nlq)) 

print(checker.check_groupby(nlq))
print(checker.check_aggregation(nlq)) 

{'TIME_PERIOD_ID': 6, 'ESTIMATE': 3}
SELECT ESTIMATE
FROM CDC
GROUP BY TIME_PERIOD_ID
None


In [None]:
checker = nlq_parser() 
nlq = sample_queries[2] 
print(checker.check_atts(nlq)) 

print(checker.check_select(nlq))  
print(checker.check_from(nlq)) 

print(checker.check_groupby(nlq)) 
print(checker.check_aggregation(nlq)) 

{'NOV': 2, 'LLY': 4}
SELECT NOV
FROM STOCK
None
None


In [105]:
sample_queries = [
    "Find the total income from FRED.",
    "Show the average income by date from FRED.",
    "List all NOV and LLY values from STOCK.",
    "Display the sum of ESTIMATE grouped by GROUP_NAME in CDC.",
    "Retrieve all records from FRED.",
    "Calculate the max STANDARD_ERROR for each ESTIMATE_TYPE in CDC.",
    "Identify the min income from FRED.",
    "Group the TIME_PERIOD by SUBGROUP in CDC.",
    "Tell me the mean income for each date in FRED.",
    "Count the number of records in STOCK.",
    "Find the total ESTIMATE for each TIME_PERIOD_ID in CDC.",
    "Show all records from CDC where CLASSIFICATION equals 'A'.",
    "List all available GROUP_NAME values from CDC.",
    "Display the max NOV value from STOCK.",
    "Retrieve SUBTOPIC and CLASSIFICATION grouped by SUBTOPIC_ID in CDC.",
    "Calculate the average income grouped by date in FRED.",
    "Identify all TIME_PERIOD and their respective ESTIMATE values from CDC.",
    "Find the sum of LLY values from STOCK.",
    "Show all GROUP_ID values in CDC where ESTIMATE_TYPE_ID equals 1.",
    "List all attributes from STOCK grouped by DATE."
]

## Example queries:

1) `SELECT ... WHERE ...`
   - **Example Query** "Find all records in the CDC data where the age is 65."

2) `SELECT ... GROUP BY ...`
   - **Example Query:** "Group all entries in the FRED dataset by state and display each state's population."

2.1) `SELECT ... GROUP BY ... HAVING ...`
   - **Example Query:** "Show all cities in the CDC data that have more than 10,000 cases, grouped by city."

3) `SELECT ... WHERE ... ORDER BY ... LIMIT ...`
   - **Example Query:** "Retrieve the top 5 states from the WEGOVY dataset where the usage is highest, sorted by usage in descending order."

4) Aggregation functions:

4.1) `SELECT SUM() ... FROM ...`
   - **Example Query:** "Calculate the total number of vaccinations recorded in the CDC data."

4.2) `SELECT AVG() ... FROM ...`
   - **Example Query:** "Find the average number of hospitalizations in each region of the FRED dataset."

4.3) `SELECT MIN() ... FROM ...`
   - **Example Query:** "Identify the minimum dosage amount in the WEGOVY dataset."

4.4) `SELECT MAX() ... FROM ...`
   - **Example Query:** "Find the highest population among all counties in the FRED dataset."

4.5) `SELECT COUNT() ... FROM ...`
   - **Example Query:** "Count the total number of entries in the CDC data."

5) Patterns of WHERE clauses:

5.1) Basic Comparison
   - **Example Query:** "List all records in the WEGOVY dataset where the dosage is exactly 2 mg."

5.2) Range Queries
   - **Example Query:** "Retrieve all records in the CDC data for patients aged between 20 and 30."
   - **Example Query:** "Find entries in the CDC dataset where the blood pressure reading is greater than 120."

5.3) Pattern Matching
   - **Example Query:** "List all counties in the FRED data that contain 'New' in their names."
   - **Example Query:** "Find all records in the CDC dataset where the region code is 'CA'."

5.4) Checking for NULL
   - **Example Query:** "Show all entries in the CDC data where the vaccination status is unknown."

5.5) Using IN
   - **Example Query:** "Display all records in the FRED dataset where the state is either 'California,' 'Texas,' or 'New York'."

5.6) Logical Conditions
   - **Example Query (AND):** "List all CDC records where age is above 50 and has received a vaccination."
   - **Example Query (OR):** "Retrieve FRED dataset records where the state is 'Texas' or population is above 1 million."
   - **Example Query (NOT):** "Find CDC data entries where the vaccination status is not 'complete'."

5.7) Using EXISTS
   - **Example Query:** "List all patients in the CDC dataset who have a recorded entry in the FRED dataset."

5.8) Subqueries
   - **Example Query:** "Retrieve all records in the CDC data where the population is equal to the maximum population recorded in the FRED dataset."

5.9) Conditional Aggregations
   - **Example Query:** "List all regions in the CDC data where there are more than 1,000 cases recorded, grouped by region."

5.10) Case-Insensitive Filtering
   - **Example Query:** "Find all CDC entries where the condition name matches 'covid' (case-insensitive)."


In [114]:
import re

class NLQtoSQLConverter:
    def __init__(self):
        # Define keywords and mappings to SQL clauses
        self.keywords = {
            'select': ['find', 'list', 'show', 'display', 'identify', 'retrieve', 'calculate'],
            'from': ['in the', 'from the', 'using the'],
            'where': ['where'],
            'group by': ['group by', 'grouped by'],
            'order by': ['ordered by', 'sorted by'],
            'limit': ['top', 'highest', 'lowest', 'only']
        }
        self.query_parts = {}

    def convert_nlq_to_sql(self, nlq):
        # Initialize the query structure
        self.query_parts = {
            'SELECT': '',
            'FROM': '',
            'WHERE': '',
            'GROUP BY': '',
            'ORDER BY': '',
            'LIMIT': ''
        }
        
        # Parse each clause
        self.extract_select_clause(nlq)
        self.extract_from_clause(nlq)
        self.extract_where_clause(nlq)
        self.extract_group_by_clause(nlq)
        self.extract_order_by_clause(nlq)
        self.extract_limit_clause(nlq)
        
        # Build and return the SQL query
        return self.build_sql_query()

    def extract_select_clause(self, nlq):
        # Adjust pattern to capture different variations
        pattern = r'\b(?:' + '|'.join(self.keywords['select']) + r')\b (.+?) \b(?:' + '|'.join(self.keywords['from']) + r')\b'
        match = re.search(pattern, nlq, re.IGNORECASE)
        if match:
            self.query_parts['SELECT'] = match.group(1).strip()
        else:
            print(f"DEBUG: No SELECT clause found in NLQ: '{nlq}'")
            raise ValueError("No SELECT clause found in the query.")

    def extract_from_clause(self, nlq):
        pattern = r'\b(?:' + '|'.join(self.keywords['from']) + r')\b (\w+)'
        match = re.search(pattern, nlq, re.IGNORECASE)
        if match:
            self.query_parts['FROM'] = match.group(1).strip()
        else:
            raise ValueError("No FROM clause found in the query.")

    def extract_where_clause(self, nlq):
        pattern = r'\bwhere\b (.+?)(?=\b(?:' + '|'.join(self.keywords['group by'] + self.keywords['order by'] + self.keywords['limit']) + r')\b|$)'
        match = re.search(pattern, nlq, re.IGNORECASE)
        if match:
            self.query_parts['WHERE'] = match.group(1).strip()

    def extract_group_by_clause(self, nlq):
        pattern = r'\b(?:group by|grouped by)\b (.+?)(?=\b(?:' + '|'.join(self.keywords['order by'] + self.keywords['limit']) + r')\b|$)'
        match = re.search(pattern, nlq, re.IGNORECASE)
        if match:
            self.query_parts['GROUP BY'] = match.group(1).strip()

    def extract_order_by_clause(self, nlq):
        pattern = r'\b(?:ordered by|sorted by)\b (.+?)(?=\b(?:' + '|'.join(self.keywords['limit']) + r')\b|$)'
        match = re.search(pattern, nlq, re.IGNORECASE)
        if match:
            self.query_parts['ORDER BY'] = match.group(1).strip()

    def extract_limit_clause(self, nlq):
        pattern = r'\b(?:top|highest|lowest|only)\b (\d+)'
        match = re.search(pattern, nlq, re.IGNORECASE)
        if match:
            self.query_parts['LIMIT'] = match.group(1).strip()

    def build_sql_query(self):
        # Assemble the SQL query string based on parsed components
        sql_query = f"SELECT {self.query_parts['SELECT']} FROM {self.query_parts['FROM']}"
        if self.query_parts['WHERE']:
            sql_query += f" WHERE {self.query_parts['WHERE']}"
        if self.query_parts['GROUP BY']:
            sql_query += f" GROUP BY {self.query_parts['GROUP BY']}"
        if self.query_parts['ORDER BY']:
            sql_query += f" ORDER BY {self.query_parts['ORDER BY']}"
        if self.query_parts['LIMIT']:
            sql_query += f" LIMIT {self.query_parts['LIMIT']}"
        
        return sql_query

# Example usage
converter = NLQtoSQLConverter()
for i in example_queries: 
    print(i) 
    sql_query = converter.convert_nlq_to_sql(i)
    print(sql_query)  
    print("-"*60)



Find all records in the CDC data where the age is 65.
SELECT all records FROM CDC WHERE the age is 65.
------------------------------------------------------------
Show all cities in the CDC data that have more than 10,000 cases, grouped by city.
SELECT all cities FROM CDC GROUP BY city.
------------------------------------------------------------
Retrieve the top 5 states from the WEGOVY dataset where the usage is highest, sorted by usage in descending order.
SELECT the top 5 states FROM WEGOVY WHERE the usage is ORDER BY usage in descending order. LIMIT 5
------------------------------------------------------------
Calculate the total number of vaccinations recorded in the CDC data.
SELECT the total number of vaccinations recorded FROM CDC
------------------------------------------------------------
Find the average number of hospitalizations in each region of the FRED dataset.
DEBUG: No SELECT clause found in NLQ: 'Find the average number of hospitalizations in each region of the FR

ValueError: No SELECT clause found in the query.