In [None]:
from tqdm import tqdm
import pandas as pd
import re, time
import numpy as np
from utils import structure_from_connection, tables_from_connection
from sqlalchemy import create_engine
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import sqlparse, sqlite3

In [None]:
# Sample query, later fed to parse_literals(): the 10 (Артикул, Номенклатура)
# pairs with the most rows in `sales`. It excludes rows with a NULL Артикул,
# articles matching 'u%', and rows whose operation is 'Закрытие месяца'.
# NOTE(review): Хозяйственнаяоперация only exists after prepare_column_names()
# has removed the spaces from the original header (presumably
# "Хозяйственная операция"); confirm against the spreadsheet.
# NOTE(review): `!=` also drops rows where Хозяйственнаяоперация IS NULL, and
# SQLite's LIKE is ASCII case-insensitive, so 'U%' codes are excluded too.
# Confirm both are intended.
query = r'''SELECT Артикул, Номенклатура, COUNT(*) AS sales_count
FROM sales
WHERE Артикул IS NOT NULL
  AND Артикул NOT LIKE 'u%'
  AND Хозяйственнаяоперация != 'Закрытие месяца'
GROUP BY Артикул, Номенклатура
ORDER BY sales_count DESC
LIMIT 10;'''

In [None]:
from sqlalchemy import text  # wraps raw SQL strings for conn.execute() in later cells

# Read the sales export and mirror it into an in-memory SQLite database
# ('sqlite://'). Later cells query it with plain SQL through `conn`.
# NOTE(review): to_sql is called without index=False, so (pandas default) the
# DataFrame index is presumably written as an extra column; confirm it is wanted.
table = pd.read_excel('2023_04_Продажи_код_артикул.xlsx')
engine = create_engine('sqlite://')  # echo defaults to False
table.to_sql('sales', engine)
conn = engine.connect()

In [None]:
# Display the schema before prepare_column_names() runs; column names may still contain spaces here.
structure_from_connection(conn)

In [None]:
def prepare_column_names(conn : 'sqlalchemy.engine.Connection'):
    """Rename every column whose name contains a space to an alphanumeric-only name.

    For each table reported by ``structure_from_connection(conn)``, a column whose
    name contains a space is renamed to the same name with every character that
    fails ``str.isalnum`` removed. For example, ``'Хозяйственная операция'``
    becomes ``'Хозяйственнаяоперация'``, so later SQL can reference it without
    quoting. Columns without spaces are left alone, so calling this twice is
    harmless.

    Parameters
    ----------
    conn
        A SQLAlchemy connection. Statements are run as
        ``conn.execute(text(...))``, which a raw DB-API connection would not accept.

    Raises
    ------
    ValueError
        If a column name with spaces has no alphanumeric characters at all, so no
        new name can be derived.
    """
    # Local import so the function does not rely on another cell having imported `text`.
    from sqlalchemy import text

    def quote(identifier):
        # Standard SQL identifier quoting: wrap in double quotes, double embedded quotes.
        return '"' + identifier.replace('"', '""') + '"'

    structure = structure_from_connection(conn)
    for table in structure:
        table_name = table['table_name']
        for column in table['columns']:
            if ' ' not in column:
                continue
            new_name = ''.join(char for char in column if char.isalnum())
            if not new_name:
                raise ValueError(
                    f'cannot derive a new name for column {column!r} of table {table_name!r}'
                )
            # Quote every identifier: the cleaned name may start with a digit or be a
            # reserved word (e.g. ORDER), which is invalid as a bare identifier.
            conn.execute(text(
                f'ALTER TABLE {quote(table_name)} RENAME COLUMN {quote(column)} TO {quote(new_name)}'
            ))

In [None]:
# Remove spaces and other non-alphanumeric characters from column names that
# contain spaces, so the SQL in this notebook can use them unquoted.
# Re-running is a no-op, because renamed columns no longer contain spaces.
prepare_column_names(conn)  

In [None]:
import zss
import sqlparse

def pretty_print(node, shift=''):
    """Print ``node`` and all its descendants as an indented tree, one per line.

    Each line is ``shift + str(node)``; every extra level of depth adds four
    spaces. Works with any object exposing a ``children`` iterable, e.g. ``SqlNode``.

    Parameters
    ----------
    node
        Root of the tree to print; must have a ``children`` attribute.
    shift : str, optional
        Prefix for the root's line (default ``''``).
    """
    print(shift + str(node))
    # Indent children by exactly one level. The original added the indent twice
    # (once via ``+=`` and once in the call).
    child_shift = shift + '    '
    for child in node.children:
        pretty_print(child, child_shift)


class SqlNode:
    """Wrap a sqlparse token tree so it can be compared with ``zss``.

    Instance attributes:
      * ``raw_node``: the sqlparse token being wrapped;
      * ``label``: the text compared by the edit-distance cost function;
      * ``children``: list of ``SqlNode`` wrappers for the sub-tokens.

    A node whose *exact* type is ``sqlparse.sql.Token`` or
    ``sqlparse.sql.Identifier`` is a leaf labelled with ``str(node.value)``.
    Every other node is labelled with its class name (e.g. ``'Where'``) and gets
    one child per non-whitespace entry of ``node.tokens``, in order.
    """

    def __init__(self, node):
        self.children = []
        self.raw_node = node
        # Exact-type test, not isinstance(): sqlparse's token groups (TokenList)
        # subclass Token, so isinstance would turn every node into a leaf.
        # An Identifier is deliberately kept whole as a single leaf.
        leaf_types = (sqlparse.sql.Token, sqlparse.sql.Identifier)
        if type(node) in leaf_types:
            self.label = str(node.value)
        else:
            self.label = type(node).__name__
            self.children = [SqlNode(sub) for sub in node.tokens if not sub.is_whitespace]

    def __repr__(self):
        return f'{type(self.raw_node)} {self.label}'

    # Plain instance methods on purpose: ratio() passes them unbound
    # (SqlNode.get_children / SqlNode.get_label) to zss, which calls them with a node.
    def get_children(self):
        """Return the list of child ``SqlNode`` objects (empty for leaves)."""
        return self.children

    def get_label(self):
        """Return the label used in edit-distance comparisons."""
        return self.label

def dist_comp(node1, node2):
    """0/1 label cost used as ``label_dist`` for ``zss.simple_distance`` in ``ratio``.

    Despite the parameter names, the arguments are node labels (the values from
    ``SqlNode.get_label``). Returns ``0`` when they are equal and ``1`` otherwise.
    """
    return 0 if node1 == node2 else 1

def ratio(tree1 : SqlNode, tree2 : SqlNode):
    """Return a similarity score in ``[0, 1]`` for two ``SqlNode`` trees.

    The tree edit distance from ``zss.simple_distance`` (with ``dist_comp`` as a
    0/1 label cost) is divided by the node count of the larger tree and
    subtracted from 1. The result is clamped at 0 in case the distance exceeds
    that node count. Trees with identical labels and shape score 1.
    """
    distance = zss.simple_distance(tree1, tree2, SqlNode.get_children, SqlNode.get_label, dist_comp)

    def count_nodes(root):
        # Size of the subtree rooted at ``root``, counting ``root`` itself.
        return 1 + sum(count_nodes(child) for child in root.children)

    larger_size = max(count_nodes(tree1), count_nodes(tree2))
    return max(1 - distance / larger_size, 0)

In [None]:
def parse_literals(sql : str, table_structure : list[dict]):
    """Map each table referenced in ``sql`` to the set of its columns mentioned there.

    Every leaf token of every statement in ``sql`` (whitespace and punctuation
    excluded) is compared with the schema:

    * a token equal to a table name registers that table, with an empty set if
      none of its columns has been seen yet;
    * any other token equal to a column name is added to the set of every table
      that has a column with that name.

    Matching is exact and case-sensitive on the raw token text, so quoted
    identifiers such as ``"Артикул"`` (whose token text includes the quotes)
    are not recognised.

    Parameters
    ----------
    sql : str
        One or more SQL statements. An empty string yields ``{}``.
    table_structure : list[dict]
        Schema entries with ``'table_name'`` and ``'columns'`` keys, as
        returned by ``structure_from_connection``.

    Returns
    -------
    dict
        Table name -> set of referenced column names, with tables in the order
        they were first encountered.
    """
    tables = {table['table_name'] for table in table_structure}
    buckets = {}

    # Iterate over all statements. The previous ``parse(sql)[0]`` raised
    # IndexError on empty input and ignored every statement after the first.
    for statement in sqlparse.parse(sql):
        # flatten() yields the leaf (ungrouped) tokens of the parse tree, in order.
        for token in statement.flatten():
            if token.is_whitespace or token.ttype in sqlparse.tokens.Punctuation:
                continue
            name = token.value
            if name in tables:
                # Always a *set*: the previous ``{}`` literal created a dict, so a
                # later column of this table crashed on ``.add``.
                buckets.setdefault(name, set())
            else:
                for table in table_structure:
                    if name in table['columns']:
                        buckets.setdefault(table['table_name'], set()).add(name)

    return buckets

In [None]:
# Check which tables and columns of the (renamed) schema the sample `query` references.
structure = structure_from_connection(conn)
parse_literals(query, structure)