In [50]:
import logging
import sqlparse
from sqlparse.sql import Token, TokenList, IdentifierList, Identifier, Function, Comparison
from sqlparse.tokens import DML, Whitespace, Keyword, Name, Literal

from typing import List, Dict, Union, Optional
from functools import lru_cache



In [2]:
# Setting up basic logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


In [3]:
sql_queries = [
    "UPDATE table SET column1 = 10",
    "SELECT column FROM table",
    """SELECT table.column1, table2.column2 FROM table JOIN (SELECT column2 FROM newtable) table2 ON table.column3 = table2.column3""",
    """
    SELECT u.id, 
           (SELECT COUNT(*) FROM orders WHERE orders.user_id = u.id) as order_count,
           (SELECT AVG(amount) FROM payments WHERE payments.user_id = u.id) as average_payment
    FROM users u 
    WHERE u.registration_date BETWEEN '2020-01-01' AND '2020-12-31'
    AND (u.status = 'active' OR u.id IN (SELECT user_id FROM vip_users));"""
    ,
    """SELECT u.id
    FROM users (select * from (select * from table))""",
    "ALTER TABLE table_name ADD column_name",
    """
    SELECT * FROM (
        SELECT col1 FROM table1
        UNION
        SELECT col2 FROM table2
    ) AS subquery
    """,
    """
    WITH cte AS (
        SELECT col1 FROM table1
        UNION
        SELECT col2 FROM table2
    )
    SELECT * FROM cte
    """,
    "SELECT *    FROM    my_table WHERE   id = 1"
    
    ]


<Logger __main__ (INFO)>

In [97]:
class SqlQuery:
    def __init__(self, query: Optional[str] = None):
        """
        Initialize the SqlQuery class with an optional SQL query.

        :param query: Raw SQL query string.
        """
        self.raw_query = query
        self.normalized_query = self._normalize_query(query) if query else None
        self.tree: Optional[List[Token]] = None
        self.parsed_query: Optional[TokenList] = None

    def _normalize_query(self, query: str) -> str:
        """
        Normalizes the SQL query by formatting.

        :param query: Raw SQL query string.
        :return: Normalized SQL query string.
        """
        return sqlparse.format(query, reindent=True, keyword_case='upper', strip_whitespace=True) if query else ''

    def set_query(self, query: str, normalize: bool = True):
        """
        Sets a new SQL query and optionally normalizes it.

        :param query: Raw SQL query string.
        :param normalize: Boolean flag to normalize the query.
        """
        self.raw_query = query
        self.normalized_query = self._normalize_query(query) if normalize else query

    @lru_cache(maxsize=None)
    def create_tree(self) -> List[Token]:
        """
        Parses the SQL query (normalized, if available) and creates a parse tree.

        :return: List of Tokens representing the parse tree.
        """
        query_to_parse = self.normalized_query if self.normalized_query else self.raw_query
        if not query_to_parse:
            logger.error("No query set")
            raise ValueError("No query set")

        try:
            parsed = sqlparse.parse(query_to_parse)
            if not parsed:
                logger.error("Failed to parse the query")
                raise ValueError("Failed to parse the query")

            self.parsed_query = parsed[0]
            self.tree = self.parsed_query.tokens
            return self.tree
        except Exception as e:
            logger.exception(f"Error while parsing the query: {e}")
            raise

    def flatten_tree(self, tokens: Optional[List[Token]] = None) -> List[str]:
        """
        Flattens the parse tree to a list of token strings.

        :param tokens: Optional list of tokens to flatten.
        :return: List of strings representing flattened tokens.
        """
        if tokens is None:
            tokens = self.tree

        if tokens is None:
            logger.error("Parse tree is not set")
            raise ValueError("Parse tree is not set")

        flat_tokens = []
        for token in tokens:
            if isinstance(token, TokenList):
                flat_tokens.extend(self.flatten_tree(token.tokens))
            else:
                flat_tokens.append(token.normalized)
        return flat_tokens

    def is_read_only(self) -> bool:
        """
        Checks if the set query is read-only (SELECT).

        :return: Boolean indicating if the query is read-only.
        """
        if not self.tree:
            logger.error("Parse tree is not set")
            raise ValueError("Parse tree is not set")

        modifying_keywords = ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER"]
        flat_tokens = self.flatten_tree()
        return not any(keyword in flat_tokens for keyword in modifying_keywords)

    def process_token(self, token) -> Union[Dict, str, None]:
        """
        Process a single SQL token into a dictionary representation.

        :param token: The SQL token to process.
        :return: A dictionary representation of the token, or None for whitespace.
        """
        if token.is_group:
            # children = [self.process_token(child) for child in token.tokens if child.ttype is not Whitespace]
            children = [self.process_token(child) for child in token.tokens]
            return {
                "type": type(token).__name__,
                "value": token.normalized,
                "children": [c for c in children if c]
            }
        else:
            # if token.ttype is Whitespace:
            #     return None
            return {
                "type": token.ttype,
                "value": token.normalized
            }

    def tree_to_dict(self, tokens=None) -> Dict:
        """
        Converts a list of SQL tokens to a nested dictionary representation.

        :param tokens: List of SQL tokens.
        :return: Nested dictionary representing the SQL tokens.
        """
        if tokens is None:
            tokens = self.tree

        return {"type": "ROOT", "children": [self.process_token(token) for token in tokens if self.process_token(token)]}



In [98]:
query = SqlQuery(query=sql_queries[0])

In [99]:
query.create_tree()

[<DML 'UPDATE' at 0x7F7E8DA2C340>,
 <Whitespace ' ' at 0x7F7E8DBD7FA0>,
 <Keyword 'TABLE' at 0x7F7E8DBBCC40>,
 <Newline ' ' at 0x7F7E8DAB5820>,
 <Keyword 'SET' at 0x7F7E8DAB5A00>,
 <Whitespace ' ' at 0x7F7E8DAB59A0>,
 <Comparison 'column...' at 0x7F7E8DBC6B30>]

In [100]:
query.tree_to_dict()

{'type': 'ROOT',
 'children': [{'type': Token.Keyword.DML, 'value': 'UPDATE'},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': Token.Keyword, 'value': 'TABLE'},
  {'type': Token.Text.Whitespace.Newline, 'value': '\n'},
  {'type': Token.Keyword, 'value': 'SET'},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': 'Comparison',
   'value': 'column1 = 10',
   'children': [{'type': 'Identifier',
     'value': 'column1',
     'children': [{'type': Token.Name, 'value': 'column1'}]},
    {'type': Token.Text.Whitespace, 'value': ' '},
    {'type': Token.Operator.Comparison, 'value': '='},
    {'type': Token.Text.Whitespace, 'value': ' '},
    {'type': Token.Literal.Number.Integer, 'value': '10'}]}]}

In [96]:
query.tree_to_dict()

{'type': 'ROOT',
 'children': [{'type': Token.Keyword.DML, 'value': 'UPDATE'},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': Token.Keyword, 'value': 'TABLE'},
  {'type': Token.Text.Whitespace.Newline, 'value': '\n'},
  {'type': Token.Keyword, 'value': 'SET'},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': 'Comparison',
   'value': 'column1 = 10',
   'children': [{'type': 'Identifier',
     'value': 'column1',
     'children': [{'type': Token.Name, 'value': 'column1'}]},
    {'type': Token.Operator.Comparison, 'value': '='},
    {'type': Token.Literal.Number.Integer, 'value': '10'}]}]}

In [72]:
for token in query.tree:
    print(token)
    try:
        print(f"has_ancestor: {token.has_ancestor}")
    except: pass
    try:
        print(f"get_sublists: {token.get_sublists}")
    except:
        pass
    print(f"ttype: {token.ttype}")
   
    try:
        print(f"tokens: {token.tokens}")
    except: pass
    # print(f"token_first: {token.token_first}")
    # print(f"token_next: {token.token_next}")
    # print(f"token_prev: {token.token_prev}")
    print(f"is_group: {token.is_group}")
    print(type(token))
    print()

UPDATE
has_ancestor: <bound method Token.has_ancestor of <DML 'UPDATE' at 0x7F7E8D9660A0>>
ttype: Token.Keyword.DML
is_group: False
<class 'sqlparse.sql.Token'>

 
has_ancestor: <bound method Token.has_ancestor of <Whitespace ' ' at 0x7F7E8D9661C0>>
ttype: Token.Text.Whitespace
is_group: False
<class 'sqlparse.sql.Token'>

TABLE
has_ancestor: <bound method Token.has_ancestor of <Keyword 'TABLE' at 0x7F7E8D957B80>>
ttype: Token.Keyword
is_group: False
<class 'sqlparse.sql.Token'>



has_ancestor: <bound method Token.has_ancestor of <Newline ' ' at 0x7F7E8D957460>>
ttype: Token.Text.Whitespace.Newline
is_group: False
<class 'sqlparse.sql.Token'>

SET
has_ancestor: <bound method Token.has_ancestor of <Keyword 'SET' at 0x7F7E8D957E80>>
ttype: Token.Keyword
is_group: False
<class 'sqlparse.sql.Token'>

 
has_ancestor: <bound method Token.has_ancestor of <Whitespace ' ' at 0x7F7E8D957EE0>>
ttype: Token.Text.Whitespace
is_group: False
<class 'sqlparse.sql.Token'>

column1 = 10
has_ancestor: <

In [61]:
query.flatten_tree()

['UPDATE', ' ', 'TABLE', '\n', 'SET', ' ', 'column1', ' ', '=', ' ', '10']

In [62]:
"UPDATE table SET column1 = 10"

'UPDATE table SET column1 = 10'

In [63]:
query.tree_to_dict() 

TypeError: isinstance() arg 2 must be a type or tuple of types

In [24]:
query = "SELECT * FROM users"
sql_query = SqlQuery(query)
sql_query.create_tree()

[<DML 'SELECT' at 0x7F7E8C11FC40>,
 <Whitespace ' ' at 0x7F7E8C11FE80>,
 <Wildcard '*' at 0x7F7E8C11FEE0>,
 <Newline ' ' at 0x7F7E8C11FF40>,
 <Keyword 'FROM' at 0x7F7E8C11FFA0>,
 <Whitespace ' ' at 0x7F7E8C11E040>,
 <Identifier 'users' at 0x7F7E8C120430>]

In [31]:
sql_query.tree_to_dict()

{'type': 'ROOT',
 'children': [{'type': Token.Keyword.DML, 'value': 'SELECT'},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': Token.Wildcard, 'value': '*'},
  {'type': Token.Text.Whitespace.Newline, 'value': '\n'},
  {'type': Token.Keyword, 'value': 'FROM'},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': 'TokenList',
   'value': 'users',
   'children': [{'type': Token.Name, 'value': 'users'}]}]}

In [36]:
sql_query.tree[-1]

<Identifier 'users' at 0x7F7E8C120430>

In [18]:
def test_tree_to_dict_select():
    query = "SELECT * FROM users"
    expected_dict = {
        "type": "ROOT",
        "children": [
            {"type": sqlparse.tokens.DML, "value": "SELECT"},
            {"type": sqlparse.tokens.Wildcard, "value": "*"},
            {"type": sqlparse.tokens.Keyword, "value": "FROM"},
            {"type": None, "value": "users"}
        ]
    }
    sql_query = SqlQuery(query)
    sql_query.create_tree()
    print(sql_query.tree_to_dict())
    print(expected_dict)
    # assert sql_query.tree_to_dict() == expected_dict, "Failed tree to dict for SELECT query"


In [19]:
test_tree_to_dict_select()

{'type': 'ROOT', 'children': [{'type': Token.Keyword.DML, 'value': 'SELECT'}, {'type': Token.Text.Whitespace, 'value': ' '}, {'type': Token.Wildcard, 'value': '*'}, {'type': Token.Text.Whitespace.Newline, 'value': '\n'}, {'type': Token.Keyword, 'value': 'FROM'}, {'type': Token.Text.Whitespace, 'value': ' '}, {'type': 'TokenList', 'value': 'users', 'children': [{'type': Token.Name, 'value': 'users'}]}]}
{'type': 'ROOT', 'children': [{'type': Token.Keyword.DML, 'value': 'SELECT'}, {'type': Token.Wildcard, 'value': '*'}, {'type': Token.Keyword, 'value': 'FROM'}, {'type': None, 'value': 'users'}]}


In [22]:
def test_tree_to_dict_select():
    query = "SELECT * FROM users"
    expected_dict = {
        "type": "ROOT",
        "children": [
            {"type": sqlparse.tokens.DML, "value": "SELECT", "children": []},
            {"type": sqlparse.tokens.Whitespace, "value": " ", "children": []},
            {"type": sqlparse.tokens.Wildcard, "value": "*", "children": []},
            {"type": sqlparse.tokens.Whitespace, "value": " ", "children": []},
            {"type": sqlparse.tokens.Keyword, "value": "FROM", "children": []},
            {"type": sqlparse.tokens.Whitespace, "value": " ", "children": []},
            {"type": None, "value": "users", "children": []}
        ]
    }
    sql_query = SqlQuery(query)
    sql_query.create_tree()
    print(sql_query.tree_to_dict())
    # print(expected_dict)
    # assert sql_query.tree_to_dict() == expected_dict, "Failed tree to dict for SELECT query"


In [23]:
test_tree_to_dict_select()

{'type': 'ROOT', 'children': [{'type': Token.Keyword.DML, 'value': 'SELECT'}, {'type': Token.Text.Whitespace, 'value': ' '}, {'type': Token.Wildcard, 'value': '*'}, {'type': Token.Text.Whitespace.Newline, 'value': '\n'}, {'type': Token.Keyword, 'value': 'FROM'}, {'type': Token.Text.Whitespace, 'value': ' '}, {'type': 'TokenList', 'value': 'users', 'children': [{'type': Token.Name, 'value': 'users'}]}]}
