In [1256]:
import numpy as np
import nltk
import re
from nltk.stem import WordNetLemmatizer, PorterStemmer
from nltk.corpus import stopwords

In [1257]:
text="""Fetch names of student whose 10th marks are greater
than 85, 12th marks greater than 80 and aggregate is greater than 65
in descending order. """
# text="""count the number of students whose 10th marks are less than 35"""
# text = """fetch names of all states"""
# text = """fetch names of all states whose population is greater than 10000000 and less than 20000000"""
# text = """fetch first 3 names of all states whose population is between 10000000 and 20000000"""
# text = """what is the total population of state gujarat or rajasthan?"""
# text = """which city has the highest population?"""
# text = """which states has average population greater than 10000000?""" 

In [1258]:
ignore_list = ["the", "a", "are", "is", "to", "in", "than"]

In [1259]:
def process_partial_dict(partial_dict):
    lemmatizer = WordNetLemmatizer()
    replacer = {}
    for types in partial_dict_raw:
        for key in partial_dict_raw[types]:
            for word in partial_dict_raw[types][key]:
                replacer[lemmatizer.lemmatize(word.lower())] = (key.lower(), types.lower())
    return replacer

In [1260]:
partial_dict_raw = {
    "select":{
        "SELECT": ["SELECT", "FETCH", "GET", "SHOW", "LIST", "display", "which"]
    },
    "from":{
        "FROM": ["FROM", "IN", "OF"]
    },
    "selector":{
        "*": [ "ALL"]
    },
    "keyword" : {
        "WHERE": ["WHERE", "IF", "WHEN", "IF", "WHEN", "WHOSE"],
        "ORDER BY": ["ORDER BY", "SORT BY", "ORDER", "SORT"],
        "AND": ["AND", "BOTH", "TOGETHER"],
        "OR": ["OR", "EITHER", "ONE OF"],
        "NOT": ["NOT", "NEITHER", "NONE OF"],
        "LIMIT": ["LIMIT", "TOP", "FIRST"],
        "OFFSET": ["OFFSET", "SKIP", "AFTER"],
        "GROUP BY": ["GROUP BY", "GROUP"],
        "HAVING": ["HAVING", "WITH"],
        "DISTINCT": ["DISTINCT", "UNIQUE"]
    },
    "option":{
        "DESC": ["DESC", "DESCENDING", "DECREASING", "DECREASE"],
        "ASC": ["ASC", "ASCENDING", "INCREASING", "INCREASE"]
    },
    "function":{
        "AVG": ["AVG", "AVERAGE", "MEAN"],
        "COUNT": ["COUNT", "TOTAL"],
        "MAX": ["MAX", "MAXIMUM", "HIGHEST"],
        "MIN": ["MIN", "MINIMUM", "LOWEST"],
        "SUM": ["SUM", "TOTAL"]
    },
    "operator":{
        "LIKE": ["LIKE", "AS"],
        "IN": ["IN", "WITHIN"],
        "IS NULL": ["IS NULL", "IS EMPTY", "IS BLANK"],
        "IS NOT NULL": ["IS NOT NULL", "IS NOT EMPTY", "IS NOT BLANK"],
        "BETWEEN": ["BETWEEN", "IN BETWEEN", "WITHIN"],
        "=": ["=", "EQUAL TO", "IS"],
        ">": [">", "GREATER"],
        "<": ["<", "LESS"],
        ">=": [">=", "GREATER THAN EQUAL TO"],
        "<=": ["<=", "LESS THAN EQUAL TO"],
        "!=": ["!=", "NOT EQUAL TO"],
        "*": ["*", "MULTIPLY"],
    }
}

replacer_dict = process_partial_dict(partial_dict_raw)

In [1261]:
text = " ".join([x for x in text.split() if x not in ignore_list])

In [1262]:
# Tokenize the text
tokens = nltk.word_tokenize(text)
tokens

['Fetch',
 'names',
 'of',
 'student',
 'whose',
 '10th',
 'marks',
 'greater',
 '85',
 ',',
 '12th',
 'marks',
 'greater',
 '80',
 'and',
 'aggregate',
 'greater',
 '65',
 'descending',
 'order',
 '.']

In [1263]:
info = {
    'tables': ['students', 'marks'],
    'columns': {
        'students': [('id', 'number'), ('name', 'string'),  ('age', 'number'),  ('marks_id', 'number')],
        'marks': [('id', 'number'), ('10th', 'number'), ('12th', 'number'), ('aggregate', 'number')]
    },
    'foreign_keys': {
        'students': {
            'marks_id': 'marks'
        }
    }
}

# info = {
#     'tables': ['city', 'state'],
#     'columns': {
#         'city': [('id', 'number'), ('name', 'string'), ('population', 'number'), ('state_id', 'number')],
#         'state': [('id', 'number'), ('name', 'string')]
#     },
#     'foreign_keys': {
#         'city': {
#             'state_id': 'state'
#         }
#     }
# }

In [1264]:
selected_tables_dict = {}
for table in info['tables']:
    selected_tables_dict[table] = 0
selected_tables_dict

{'students': 0, 'marks': 0}

In [1265]:
lemmatizer = WordNetLemmatizer()

In [1266]:
def getSimilar(x, y):
    stemmer = PorterStemmer()
    x = stemmer.stem(x)
    for i in y:
        if x == stemmer.stem(i):
            return i
    return None

In [1267]:
def enrich(x, info, selected_tables_dict, check_all=False):
    #check for keyword
    y = getSimilar(x, replacer_dict.keys())
    if y is not None:
        return replacer_dict[y]

    # check if its table name
    y = getSimilar(x, info['tables'])
    if y is not None:
        selected_tables_dict[y] = 1
        return (y, 'table')
    # check if its column name

    _table_list = [x for x in selected_tables_dict if selected_tables_dict[x] == 1] if not check_all else info['tables']

    for table in _table_list:
        y = getSimilar(x, [_x[0] for _x in info['columns'][table]])
        if y is not None:
            selected_tables_dict[table] = 1
            return (f'{table}.{y}', 'column')
    
    # check if its a cardinal number
    if x.isdigit():
        return (x, 'number')
    
    # if strict check for type of attr
    if check_all:
        _table_list = [x for x in selected_tables_dict if selected_tables_dict[x] == 1] if not check_all else info['tables']
        for table in _table_list:
            for col in info['columns'][table]:
                if col[1] == 'string':
                    return (x, 'value', f'{col[0]}')

    return x


In [1268]:
def parse_tokens(tokens, check_all = False):
    for i in range(len(tokens)):
        if isinstance(tokens[i], tuple):
            continue
        review = tokens[i]
        review = re.sub('[^0-9a-zA-Z]', ' ', review)
        review = review.lower()
        review = lemmatizer.lemmatize(review, pos='v')
        review = enrich(review, info, selected_tables_dict, check_all=check_all)
        tokens[i] = review
    return tokens

tokens = parse_tokens(tokens)
tokens = parse_tokens(tokens)
tokens = parse_tokens(tokens, check_all=True)

tokens = [x for x in tokens if not x[0].startswith(' ') ]
print(selected_tables_dict,'\n',tokens)
# nltk.pos_tag(tokens)

{'students': 1, 'marks': 1} 
 [('select', 'select'), ('students.name', 'column'), ('from', 'from'), ('students', 'table'), ('where', 'keyword'), ('marks.10th', 'column'), ('marks', 'table'), ('>', 'operator'), ('85', 'number'), ('marks.12th', 'column'), ('marks', 'table'), ('>', 'operator'), ('80', 'number'), ('and', 'keyword'), ('marks.aggregate', 'column'), ('>', 'operator'), ('65', 'number'), ('desc', 'option'), ('order by', 'keyword')]


In [1269]:
WordNetLemmatizer().lemmatize('highest', pos='v'), PorterStemmer().stem('highest')

('highest', 'highest')

In [1270]:
def nlsql_partial_process(tokens):
    tokens = [x if x not in replacer_dict.keys() else (replacer_dict[x][0].upper(),replacer_dict[x][1]) for x in tokens ]
    return tokens

tokens = nlsql_partial_process(tokens)
tokens

[('select', 'select'),
 ('students.name', 'column'),
 ('from', 'from'),
 ('students', 'table'),
 ('where', 'keyword'),
 ('marks.10th', 'column'),
 ('marks', 'table'),
 ('>', 'operator'),
 ('85', 'number'),
 ('marks.12th', 'column'),
 ('marks', 'table'),
 ('>', 'operator'),
 ('80', 'number'),
 ('and', 'keyword'),
 ('marks.aggregate', 'column'),
 ('>', 'operator'),
 ('65', 'number'),
 ('desc', 'option'),
 ('order by', 'keyword')]

In [1271]:
for i in range(len(tokens)-1):
    if tokens[i][1] == 'option' and tokens[i+1][1] == 'keyword':
        swap = tokens[i]
        tokens[i] = tokens[i+1]
        tokens[i+1] = swap

# filter
new_tokens = []
no_table = False
no_selectors = False

for i in range(len(tokens)):
    if not isinstance(tokens[i], tuple):
        pass
    elif not no_table and tokens[i][0] == 'where':
        no_table = True
        new_tokens.append(tokens[i])
    elif not no_selectors and tokens[i][0] == 'from':
        no_selectors = True
        new_tokens.append(tokens[i])
    elif no_selectors and tokens[i][1] == 'selector':
        pass
    elif no_table and tokens[i][1] == 'table':
        pass

    else:
        new_tokens.append(tokens[i])

tokens = new_tokens
tokens

[('select', 'select'),
 ('students.name', 'column'),
 ('from', 'from'),
 ('students', 'table'),
 ('where', 'keyword'),
 ('marks.10th', 'column'),
 ('>', 'operator'),
 ('85', 'number'),
 ('marks.12th', 'column'),
 ('>', 'operator'),
 ('80', 'number'),
 ('and', 'keyword'),
 ('marks.aggregate', 'column'),
 ('>', 'operator'),
 ('65', 'number'),
 ('order by', 'keyword'),
 ('desc', 'option')]

In [1272]:
def final_query(tokens, selected_tables_dict):
    query = "SELECT{unique} {selectors} FROM {tables} WHERE"

    selectors = []
    selector_replacer = ""
    unique_val = ""

    tables = selected_tables_dict.keys()
    tables = [x for x in tables if selected_tables_dict[x] == 1]
    table_replacer = ", ".join(tables)

    select_flag = False
    from_flag = False
    limit_flag = None
    previous_selected_column = None

    i = 0
    while i < len(tokens):
        print(i, tokens[i])
        if not select_flag and tokens[i][1] == 'select':
            select_flag = True
        if not from_flag and tokens[i][1] == 'from':
            from_flag = True
        if tokens[i][0] == 'distinct':
            unique_val = ' DISTINCT'

        if tokens[i][1] == 'function':
            if tokens[i][0] == "COUNT":
                if not select_flag:
                    selectors.append('COUNT(*)')
                    select_flag = True
                else:
                    pass
            elif tokens[i][0] in ["avg", "sum"]:
                if not select_flag:
                    if tokens[i+1][1] == 'column':
                        selectors.append(f'{tokens[i][0].upper()}({tokens[i+1][0]})')
                        i+=1
                    select_flag = True
                else:
                    part_condition = f'{tokens[i][0].upper()}'
                    if tokens[i+1][1] == 'column':
                        part_condition += f'({tokens[i+1][0]})'
                        i+=1
                    else:
                        part_condition += f'({previous_selected_column})'
                    if tokens[i+1][1] == 'operator':
                        part_condition += f' {tokens[i+1][0]}'
                        i+=1
                    if tokens[i+1][1] == 'number':
                        part_condition += f' {tokens[i+1][0]}'
                    else:
                        part_condition += f" '{tokens[i+1][0]}'"
                    i+=1
                    query += f' AND {part_condition}' if query[-5:] != 'WHERE' else f' {part_condition}'
                    
            elif tokens[i][0] in ["min","max"]:
                if tokens[i+1][1] == 'column':
                    selectors.append(f'{tokens[i][0].upper()}({tokens[i+1][0]})')
                    i+=1
                select_flag = True

        elif tokens[i][0] == 'limit':
            if tokens[i+1][1] == 'number':
                limit_flag = tokens[i+1][0]
            else:
                limit_flag = 1

        elif select_flag and not from_flag:
            if tokens[i][0] == '*':
                selectors.append('*')
            elif tokens[i][1] == 'column':
                selectors = [x for x in selectors if x != '*']
                selectors.append(tokens[i][0])
            
            

        elif from_flag:
            
            if tokens[i][1] in ['from', 'table']:
                pass
            elif tokens[i][1] == 'keyword' and tokens[i][0] == 'WHERE':
                query += " WHERE"
            elif tokens[i][1] == 'keyword' and tokens[i][0] == 'ORDER BY':
                query += " ORDER BY " + f"{tables[0]}.{info['columns'][tables[0]][0]}"
            elif tokens[i][1] == 'value':
                col_name = tokens[i][2]
                if tokens[i-1][1] == 'table':
                    col_name = f"{tokens[i-1][0]}.{col_name}"
                    previous_selected_column = col_name
                elif i+1 < len(tokens) and tokens[i+1][1] == 'table':
                    col_name = f"{tokens[i+1][0]}.{col_name}"
                    previous_selected_column = col_name
                else:
                    col_name = previous_selected_column

                query += f" {col_name} = '{tokens[i][0]}'"

            else:
                if tokens[i][1] == 'operator':
                    if tokens[i-1][1] != 'column':
                        query += f" {previous_selected_column[0]}"
                    else:
                        previous_selected_column = tokens[i-1]
                query += " "+tokens[i][0]
        
        i+=1
    
    if limit_flag is not None:
        query += f" LIMIT {limit_flag}"

    if len(selectors) == 0:
        selectors.append('*')
    selector_replacer = ", ".join(selectors)

    if query.split('WHERE')[1] == '':
        query = query.split('WHERE')[0].strip()
    return query.format(unique=unique_val,selectors=selector_replacer, tables=table_replacer)

final_query(tokens, selected_tables_dict)

0 ('select', 'select')
1 ('students.name', 'column')
2 ('from', 'from')
3 ('students', 'table')
4 ('where', 'keyword')
5 ('marks.10th', 'column')
6 ('>', 'operator')
7 ('85', 'number')
8 ('marks.12th', 'column')
9 ('>', 'operator')
10 ('80', 'number')
11 ('and', 'keyword')
12 ('marks.aggregate', 'column')
13 ('>', 'operator')
14 ('65', 'number')
15 ('order by', 'keyword')
16 ('desc', 'option')


'SELECT students.name FROM students, marks WHERE where marks.10th > 85 marks.12th > 80 and marks.aggregate > 65 order by desc'