In [1]:
import csv
import sqlparse as sp
from sqlparse.sql import IdentifierList, Identifier, Where, Comparison
from sqlparse.tokens import Keyword, DML, Newline, Whitespace, Text, Token

In [2]:
'''Data loading functions'''

def extract_metadata():
    path = ""
    metafile = open('files/metadata.txt', 'r') 
    metalines = metafile.readlines() 
    tables_dict = {}
    tables_meta = {}
    tables_list = {}
    is_rec = False
    is_tname = False
    cur_table = ""
    for line in metalines:
        if line.startswith('<begin_table>'):
            is_rec = True
            is_tname = True
        elif line.startswith('<end_table>'):
            if_rec = False
        elif is_tname:
            cur_table = str(line).strip()
            tables_dict[cur_table] = {}
            tables_meta[cur_table] = []
            tables_list[cur_table] = []
            is_tname = False
        else:
            tables_dict[cur_table][str(line).strip()] = []
            tables_meta[cur_table].append(str(line).strip())
    return tables_meta,tables_dict,tables_list
            
def extract_csvdata_bycols(tables_dict):
    tables_data = tables_dict
    for tn,cl in tables_dict.items(): 
        with open('files/'+tn+'.csv', newline='') as table_file:
            all_data = csv.reader(table_file,delimiter=',')
#             print(all_data)
            for row in all_data:
                for c,i in zip(cl,range(len(cl))):
                    tables_data[tn][c].append(int(row[i]))
    return tables_data

def extract_csvdata_byrows(tables_list):
    tables_data = tables_list
    for tn in tables_list: 
        with open('files/'+tn+'.csv', newline='') as table_file:
            all_data = csv.reader(table_file,delimiter=',')
#             print(all_data)
            for row in all_data:
                tables_data[tn].append([int(x) for x in row])
    return tables_data
    

In [3]:
tables_meta,tables_dict,tables_list = extract_metadata()
print(tables_meta)
print(tables_list)
tables_data_bycols = extract_csvdata_bycols(tables_dict)
print(tables_data_bycols)
tables_data_byrows = extract_csvdata_byrows(tables_list)
print(tables_data_byrows)

{'table1': ['A', 'B', 'C'], 'table2': ['D', 'E']}
{'table1': [], 'table2': []}
{'table1': {'A': [922, 640, 775, -551, 922, -354, -497, 411, 922, 858, 640], 'B': [158, 773, 85, 811, 311, 646, 335, 803, 718, 731, 773], 'C': [5727, 5058, 10164, 1534, 1318, 7063, 4549, 10519, 9020, 3668, 5058]}, 'table2': {'D': [158, 775, 86, 812, 812, 640, 336, 804, 719, 922], 'E': [11191, 14421, 5117, 12262, 16116, 5403, 6309, 12262, 12262, 13021]}}
{'table1': [[922, 158, 5727], [640, 773, 5058], [775, 85, 10164], [-551, 811, 1534], [922, 311, 1318], [-354, 646, 7063], [-497, 335, 4549], [411, 803, 10519], [922, 718, 9020], [858, 731, 3668], [640, 773, 5058]], 'table2': [[158, 11191], [775, 14421], [86, 5117], [812, 12262], [812, 16116], [640, 5403], [336, 6309], [804, 12262], [719, 12262], [922, 13021]]}


In [6]:
'''SQL functions'''
invalid_msg = 'Invalid query syntax'


def remove_wspaces(parsed_sql):
    modf_parsed_sql = []
    for token in parsed_sql:
        if token.is_whitespace:
            continue
        modf_parsed_sql.append(token)
    return modf_parsed_sql


def attr_condition(cndtn):
    c_attr = {'id1':'', 'opr':'', 'id2':''}
    for token in cndtn:
        if isinstance(token,Identifier):
            if not c_attr['id1']:
                c_attr['id1'] = token.get_name()
            else:
                c_attr['id2'] = token.get_name()
        elif token.ttype is Token.Operator.Comparison:
            c_attr['opr'] += token.value
        elif token.ttype is Token.Literal.Number.Integer:
            if not c_attr['id1']:
                c_attr['id1'] = int(token.value)
            else:
                c_attr['id2'] = int(token.value)
    return c_attr
    

def process_where(where_stmnt):
    where_dict = {'andor':"", 'conditions':[]}
    modf_where_stmnt = remove_wspaces(where_stmnt)
    for token in modf_where_stmnt:
        if token.ttype is Keyword and token.value == 'where':
            continue
        if token.ttype is Keyword and token.value == 'and':
            where_dict['andor'] = 'and'
        elif token.ttype is Keyword and token.value == 'or':
            where_dict['andor'] = 'or'
        else:
            where_dict['conditions'].append(token)
    return where_dict
    
def get_aggregate_fn(token):
    aggfncs_list = ['count','max','min','sum','avg']
#     print(token,type(token))
    aggfn_dict = {}
    for af in aggfncs_list:
        if af+'(' in token:
            aggfn_dict['func'] = af
            aggfn_dict['col'] = token[token.find('(')+1:token.find(')')].upper()
    return aggfn_dict

    
def process_query(parsed_sql):
    modf_parsed_sql = remove_wspaces(parsed_sql)
    curr_token = ""
    q_columns,q_tables,q_groupby,q_conditions = [],[],[],{}
    q_aggfn = {'func':[], 'col':[]}
    q_orderby = {'col':'', 'order':None}
    q_distinct = False
    for token in modf_parsed_sql:
#         print("========",token,token.ttype,type(token))
        if token.ttype is DML and token.value == 'select':
            curr_token = 'select'
            continue
        if token.ttype is Keyword and token.value == 'from':
            if curr_token not in ['select','distinct']:
                print('from:',invalid_msg)
#                 exit(0)
            curr_token = 'from'
            continue
        if isinstance(token,Where):
            if curr_token not in ['from']:
                print('where:',invalid_msg)
#                 exit(0)
            curr_token = 'where'
            q_conditions = process_where(token)
            continue
        if token.ttype is Keyword and token.value == 'group by':
            if curr_token not in ['from','where']:
                print('group by:',invalid_msg)
#                 exit(0)
            curr_token = 'group by'
            continue
        if token.ttype is Keyword and token.value == 'order by':
            if curr_token not in ['from','where','group by']:
                print('order by:',invalid_msg)
#                 exit(0)
            curr_token = 'order by'
            continue
        if curr_token == 'select':
            if token.ttype is Keyword and token.value == 'distinct':
                if not q_columns and not q_aggfn['func']:
                    print('distinct:',invalid_msg)
#                     exit(0)
                q_distinct = True
            elif isinstance(token, IdentifierList):
                for c in token.get_identifiers():
#                     print("+++++",str(c),c.get_name())
                    aggfn = get_aggregate_fn(str(c))
                    if aggfn:
                        if q_aggfn['func']:
                            q_aggfn['func'].append(aggfn['func'])
                            q_aggfn['col'].append(aggfn['col'])
                        else:
                            q_aggfn['func'] = [aggfn['func']]
                            q_aggfn['col'] = [aggfn['col']]
                    else:
                        q_columns.append(c.get_name().upper())
            elif isinstance(token, Identifier):
                q_columns.append(token.get_name().upper())
            elif token.ttype is Token.Wildcard:
                q_columns = ['*']
            elif isinstance(token,sp.sql.Function):
                aggfn = get_aggregate_fn(str(token))
                if aggfn:
                    if q_aggfn['func']:
                        q_aggfn['func'].append(aggfn['func'])
                        q_aggfn['col'].append(aggfn['col'])
                    else:
                        q_aggfn['func'] = [aggfn['func']]
                        q_aggfn['col'] = [aggfn['col']]
            else:
                print('select:',invalid_msg)
#                 exit(0)
        elif curr_token == 'from':
            if isinstance(token, IdentifierList):
                for t in token.get_identifiers():
                    q_tables.append(t.get_name())
            elif isinstance(token, Identifier):
                q_tables.append(token.get_name())
            else:
                print('from:',invalid_msg)
#                 exit(0)
        elif curr_token == 'group by':
            if isinstance(token, IdentifierList):
                for c in token.get_identifiers():
                    q_groupby.append(c.get_name().upper())
            elif isinstance(token, Identifier):
                q_groupby.append(token.get_name().upper())
            elif token.ttype is Token.Wildcard:
                q_groupby = ['*']
            else:
                print('group by:',invalid_msg)
#                 exit(0)
        elif curr_token == 'order by':
            if isinstance(token, Identifier):
                q_orderby['col'] = token.get_name().upper()
                q_orderby['order'] = token.get_ordering()
            else:
                print('order by:',invalid_msg)
#                 exit(0)
    q_attributes = {}
    q_attributes['q_tables'] = q_tables
    q_attributes['q_cols'] = q_columns
    q_attributes['q_conditions'] = q_conditions
    q_attributes['q_groupby'] = q_groupby
    q_attributes['q_aggfn'] = q_aggfn
    q_attributes['q_distinct'] = q_distinct
    q_attributes['q_orderby'] = q_orderby
    return q_attributes

def join_tables(tables):
    join_data = tables_data_byrows[tables[0]]
#     num_rows = len(join_data)
    disp_cnames = []
    disp_cnames += tables_meta[tables[0]]
#     print('b4 lp',disp_cnames,tables_meta[tables[0]])
    for t in tables[1:]:
        temp_join = []
        for rj in join_data:
            for rt in tables_data_byrows[t]:
                temp_join.append(rj+rt)
        join_data = temp_join
        disp_cnames += tables_meta[t]
#         print(disp_cnames)
    return join_data, disp_cnames

def get_distinct(q_rows):
    return set([tuple(row) for row in q_rows])
    
def select_rows(q_rows,q_tables,q_cols):
    cols_idx = []
    allcols = []
    for t in q_tables:
        allcols += tables_meta[t]
    if q_cols[0]=='*':
        return q_rows,allcols
    for c in q_cols:
        cols_idx.append(allcols.index(c.upper()))
    sel_rows = []
    for row in q_rows:
        tmprow = []
        for i in cols_idx:
            tmprow.append(row[i])
        sel_rows.append(tmprow)
    return sel_rows,q_cols

def display(q_rows,q_tables,disp_cnames):
    print("--------OUTPUT--------")
    for c in disp_cnames:
        tab = ''
#         tab = any([c in cl for t,cl in tables_meta.items()])
        for t,cl in tables_meta.items():
            if c in cl:
                tab = t
        if not tab:
            print(c,end='\t')
        else:
            print(tab+'.'+c,end='\t')
    print()
    for row in q_rows:
        for i in range(len(disp_cnames)):
            print(row[i],end='\t\t')
        print()
    print("\nRows displayed:",len(q_rows))
    
def compare_cols(row,c_attr,fc1,fc2,xc1,xc2):
    if (c_attr['opr'] == "=" and ((fc1 and fc2 and row[xc1]==row[xc2]) or ((not fc2) and row[xc1]==c_attr['id2']))) or \
       (c_attr['opr'] == ">" and ((fc1 and fc2 and row[xc1]>row[xc2]) or ((not fc2) and row[xc1]>c_attr['id2']))) or \
       (c_attr['opr'] == "<" and ((fc1 and fc2 and row[xc1]<row[xc2]) or ((not fc2) and row[xc1]<c_attr['id2']))) or \
       (c_attr['opr'] == "!=" and ((fc1 and fc2 and row[xc1]!=row[xc2]) or ((not fc2) and row[xc1]!=c_attr['id2']))) or \
       (c_attr['opr'] == "<=" and ((fc1 and fc2 and row[xc1]<=row[xc2]) or ((not fc2) and row[xc1]<=c_attr['id2']))) or \
       (c_attr['opr'] == ">=" and ((fc1 and fc2 and row[xc1]>=row[xc2]) or ((not fc2) and row[xc1]>=c_attr['id2']))):
        return True
    return False
    
def execute_where(q_rows,q_tables,q_where):
    cnames = []
    for t in q_tables:
        cnames += tables_meta[t]
    sel_rows = [False for i in range(len(q_rows))]
    ci=0
    for cndtn in q_where['conditions']:
        cndtn = remove_wspaces(cndtn)
        c_attr = attr_condition(cndtn)
        print(c_attr)
        fc1=False
        fc2=False
        xc1=-1
        xc2=-1
        if isinstance(c_attr['id1'],str):
            fc1 = True
            xc1 = cnames.index(c_attr['id1'].upper())
        if isinstance(c_attr['id2'],str):
            fc2 = True
            xc2 = cnames.index(c_attr['id2'].upper())
        print(fc1 and fc2)
        if ci==0:
            r = 0
            for row in q_rows:
                sel_rows[r] = compare_cols(row,c_attr,fc1,fc2,xc1,xc2)
                r+=1
        else:
            r=0
            if q_where['andor']=='and':
                for row in q_rows:
                    if sel_rows[r]:
                        sel_rows[r] = compare_cols(row,c_attr,fc1,fc2,xc1,xc2)
                    r+=1
            elif q_where['andor']=='or':
                for row in q_rows:
                    if not sel_rows[r]:
                        sel_rows[r] = compare_cols(row,c_attr,fc1,fc2,xc1,xc2)
                    r+=1
        ci+=1
    r=0
    new_qrows = []
    for row in q_rows:
        if sel_rows[r]:
            new_qrows.append(row)
        r+=1
    return new_qrows

def execute_aggfn(grp_rows,q_aggfn,cnames,q_grpcols='',q_cols=''):
    grp_out = []
    disp_cnames = []
    if q_grpcols:
        for gc in q_grpcols:
            if gc in q_cols:
                gcix = cnames.index(gc.upper())
                grp_out.append(grp_rows[0][gcix])
                disp_cnames.append(gc)
    for afi in range(len(q_aggfn['func'])):
        if q_aggfn['func'][afi] == 'max':
            ci = cnames.index(q_aggfn['col'][afi].upper())
            grp_out.append([max(i) for i in zip(*grp_rows)][ci])
            disp_cnames.append("max("+q_aggfn['col'][afi]+")")
        elif q_aggfn['func'][afi] == 'min':
            ci = cnames.index(q_aggfn['col'][afi].upper())
            grp_out.append([min(i) for i in zip(*grp_rows)][ci])
            disp_cnames.append("min("+q_aggfn['col'][afi]+")")
        elif q_aggfn['func'][afi] == 'sum':
            ci = cnames.index(q_aggfn['col'][afi].upper())
            grp_out.append([sum(i) for i in zip(*grp_rows)][ci])
            disp_cnames.append("sum("+q_aggfn['col'][afi]+")")
        elif q_aggfn['func'][afi] == 'count':
            grp_out.append(len(grp_rows))
            disp_cnames.append("count("+q_aggfn['col'][afi]+")")
        elif q_aggfn['func'][afi] == 'avg':
            ci = cnames.index(q_aggfn['col'][afi].upper())
            grp_out.append(round([sum(i) for i in zip(*grp_rows)][ci]/len(grp_rows),2))
            disp_cnames.append("avg("+q_aggfn['col'][afi]+")")
#     print(grp_out)
    return [grp_out],disp_cnames
    
            
def execute_groupby(q_rows,q_grpcols,cnames,q_aggfn,q_cols):
    gcol_idx = [cnames.index(gc) for gc in q_grpcols]
#     gcol_tuples = [row[gcol_idx] for row in q_rows]
    gcol_tuples = []
    for row in q_rows:
        gcvl = []
        for gci in gcol_idx:
            gcvl.append(row[gci])
        gcol_tuples.append(tuple(gcvl))
    gcval_map = {}
    for i in range(len(gcol_tuples)):
        if gcol_tuples[i] not in gcval_map:
            gcval_map[gcol_tuples[i]] = []
        gcval_map[gcol_tuples[i]].append(i) #+=
    new_grows = []
    for x,y in gcval_map.items():
        if q_aggfn['func']:
            grp_rows = []
            for i in y:
                grp_rows.append(q_rows[i])
            grp_rows,disp_cnames = execute_aggfn(grp_rows,q_aggfn,cnames,q_grpcols,q_cols)
#             new_grows.append(grp_rows)
            new_grows += grp_rows
        else:
            new_grows.append(list(x))
            disp_cnames = q_cols
#     print(new_grows)
    return new_grows,disp_cnames

# def flatten_groups(q_grows):
#     flat_rows = []
#     for g in q_grows:
#         flat_rows += g
#     return flat_rows

def execute_orderby(q_rows,q_orderby,disp_cnames):
    ci = disp_cnames.index(q_orderby['col'])
    print("order i",ci)
    return(sorted(q_rows, key = lambda x: x[ci], reverse=(q_orderby['order']=='DESC')))   

def execute_query(q_attributes):
    q_data = []
    cnames = []
    disp_cnames = []
    flag = False
    if not q_attributes['q_tables'] or any(t not in tables_meta for t in q_attributes['q_tables']):
        print("Table does not exist")
#         exit(0)
    for t in q_attributes['q_tables']:
        cnames += tables_meta[t]
    if len(q_attributes['q_tables'])>1:
        q_data,disp_cnames = join_tables(q_attributes['q_tables'])
    else:
        q_data = tables_data_byrows[q_attributes['q_tables'][0]]
        disp_cnames = cnames
    if q_attributes['q_conditions']:
        q_data = execute_where(q_data,q_attributes['q_tables'],q_attributes['q_conditions'])
    if q_attributes['q_groupby']:
        q_data,disp_cnames = execute_groupby(q_data,q_attributes['q_groupby'],disp_cnames,q_attributes['q_aggfn'],q_attributes['q_cols'])
        flag = True
    if q_attributes['q_aggfn']['func'] and not q_attributes['q_groupby']:
        q_data,disp_cnames = execute_aggfn(q_data,q_attributes['q_aggfn'],disp_cnames)
        flag = True
    if not flag:
        q_data,disp_cnames = select_rows(q_data,q_attributes['q_tables'],q_attributes['q_cols'])
    if q_attributes['q_distinct']:
        q_data = get_distinct(q_data)
    if q_attributes['q_orderby']['col']:
        q_data = execute_orderby(q_data,q_attributes['q_orderby'],disp_cnames)
    display(q_data,q_attributes['q_tables'],disp_cnames)
    
        
    

In [8]:
# qry_input = input().strip().lower()
qry_input = "select count(*) from;".lower()
if qry_input[-1] != ';':
    print("Semicolon missing")
#     exit(0)
# frmt_qry = sp.format(qry_input,reindent=True, keyword_case='upper')
# print(frmt_qry[1])
parsed_sql = sp.parse(qry_input[:len(qry_input)-1])[0]
print(parsed_sql.tokens)
q_attributes = process_query(parsed_sql)
print(q_attributes)
execute_query(q_attributes)

[<DML 'select' at 0x7F8C4DF42FA8>, <Whitespace ' ' at 0x7F8C4DF6C0A8>, <Function 'count(...' at 0x7F8C4DF4C750>, <Whitespace ' ' at 0x7F8C4DF6C2E8>, <Keyword 'from' at 0x7F8C4DF6C3A8>]
{'q_tables': [], 'q_cols': [], 'q_conditions': {}, 'q_groupby': [], 'q_aggfn': {'func': ['count'], 'col': ['*']}, 'q_distinct': False, 'q_orderby': {'col': '', 'order': None}}
Table does not exist


IndexError: list index out of range

In [121]:
# qry_input = input().strip()
# parsed_sql = sp.parse(qry_input)[0]
# print(parsed_sql)
# print(parsed_sql.tokens)
# print(sp.split(qry_input))
# print(qry_output[0])
# print(qry_output[1][0])
print(qry_output[4][2])

NameError: name 'qry_output' is not defined

In [32]:
print(parsed_sql.tokens[1])
print(parsed_sql.flatten())
for token in parsed_sql.tokens:
    if(token.is_whitespace):
        print("w")
    print(token)

 
<generator object TokenList.flatten at 0x7f880811f728>
select
w
 
*
w
 
from
w
 
table1, table2
w
 
where a>0


In [37]:
l = [[0, 1, 2], [20, 31, 40], [51, 60, 70], [81, 91, 11]]  
l[:][1]

[20, 31, 40]

In [44]:
tables_meta

{'table1': ['A', 'B', 'C'], 'table2': ['D', 'E']}

In [69]:
round(58.78788756, 2)

58.79

In [26]:
a = {'h':9}
# [x for x in a]
'H' in a


False