In [1]:
import logging
import sqlparse
from sqlparse.sql import Token, TokenList, IdentifierList, Identifier, Function, Comparison
from sqlparse.tokens import DML, Whitespace, Keyword, Name, Literal

from typing import List, Dict, Union, Optional
from functools import lru_cache



In [2]:
# Setting up basic logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


In [3]:
sql_queries = [
    "UPDATE table SET column1 = 10",
    "SELECT column FROM table",
    """SELECT table.column1, table2.column2 FROM table JOIN (SELECT column2 FROM newtable) table2 ON table.column3 = table2.column3""",
    """
    SELECT u.id, 
           (SELECT COUNT(*) FROM orders WHERE orders.user_id = u.id) as order_count,
           (SELECT AVG(amount) FROM payments WHERE payments.user_id = u.id) as average_payment
    FROM users u 
    WHERE u.registration_date BETWEEN '2020-01-01' AND '2020-12-31'
    AND (u.status = 'active' OR u.id IN (SELECT user_id FROM vip_users));"""
    ,
    """SELECT u.id
    FROM users (select * from (select * from table))""",
    "ALTER TABLE table_name ADD column_name",
    """
    SELECT * FROM (
        SELECT col1 FROM table1
        UNION
        SELECT col2 FROM table2
    ) AS subquery
    """,
    """
    WITH cte AS (
        SELECT col1 FROM table1
        UNION
        SELECT col2 FROM table2
    )
    SELECT * FROM cte
    """,
    "SELECT *    FROM    my_table WHERE   id = 1"
    
    ]


In [4]:

class QueryNormalizer:
    def normalize_query(self, query: str) -> str:
        """
        Normalizes the SQL query by formatting.

        :param query: Raw SQL query string.
        :return: Normalized SQL query string.
        """
        return sqlparse.format(query, reindent=True, keyword_case='upper', strip_whitespace=True) if query else ''


In [5]:
class QueryParser:
    def parse_query(self, query: str) -> List[Token]:
        """
        Parses the SQL query and creates a parse tree.

        :param query: SQL query string.
        :return: List of Tokens representing the parse tree.
        """
        try:
            parsed = sqlparse.parse(query)
            if not parsed:
                logger.error("Failed to parse the query")
                raise ValueError("Failed to parse the query")

            return parsed[0].tokens
        except Exception as e:
            logger.exception(f"Error while parsing the query: {e}")
            raise


In [6]:
class QueryAnalyzer:
    def is_read_only(self, tokens: List[Token]) -> bool:
        """
        Checks if the given tokens represent a read-only (SELECT) SQL query.

        :param tokens: List of Tokens representing the SQL query.
        :return: Boolean indicating if the query is read-only.
        """
        modifying_keywords = ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER"]
        flat_tokens = [token.normalized for token in tokens if not isinstance(token, TokenList)]
        return not any(keyword in flat_tokens for keyword in modifying_keywords)


In [7]:
class SqlQuery:
    def __init__(self, query: Optional[str] = None, normalizer: QueryNormalizer = QueryNormalizer(),
                 parser: QueryParser = QueryParser(), analyzer: QueryAnalyzer = QueryAnalyzer()):
        """
        Initialize the SqlQuery class with an optional SQL query.

        :param query: Raw SQL query string.
        :param normalizer: Instance of QueryNormalizer.
        :param parser: Instance of QueryParser.
        :param analyzer: Instance of QueryAnalyzer.
        """
        self.raw_query = query
        self.normalized_query = normalizer.normalize_query(query) if query else None
        self.tree: Optional[List[Token]] = None
        self.parser = parser
        self.analyzer = analyzer

    def set_query(self, query: str, normalize: bool = True):
        """
        Sets a new SQL query and optionally normalizes it.

        :param query: Raw SQL query string.
        :param normalize: Boolean flag to normalize the query.
        """
        self.raw_query = query
        self.normalized_query = self.normalizer.normalize_query(query) if normalize else query
        self.tree = None

    def create_tree(self) -> List[Token]:
        """
        Parses the SQL query (normalized, if available) and creates a parse tree.

        :return: List of Tokens representing the parse tree.
        """
        query_to_parse = self.normalized_query if self.normalized_query else self.raw_query
        if not query_to_parse:
            logger.error("No query set")
            raise ValueError("No query set")

        self.tree = self.parser.parse_query(query_to_parse)
        return self.tree

    def is_read_only(self) -> bool:
        """
        Checks if the set query is read-only (SELECT).

        :return: Boolean indicating if the query is read-only.
        """
        if not self.tree:
            logger.error("Parse tree is not set")
            raise ValueError("Parse tree is not set")

        return self.analyzer.is_read_only(self.tree)


In [None]:
from typing import Optional, List, Union, Dict
import sqlparse
from sqlparse.tokens import Token, TokenList
from functools import lru_cache

class QueryNormalizer:
    def normalize_query(self, query: str) -> str:
        return sqlparse.format(query, reindent=True, keyword_case='upper', strip_whitespace=True) if query else ''


class QueryParser:
    def parse_query(self, query: str) -> List[Token]:
        try:
            parsed = sqlparse.parse(query)
            if not parsed:
                logger.error("Failed to parse the query")
                raise ValueError("Failed to parse the query")
            return parsed[0].tokens
        except Exception as e:
            logger.exception(f"Error while parsing the query: {e}")
            raise


class TokenProcessor:
    def process(self, token) -> Union[Dict, str, None]:
        if token.is_group:
            children = [self.process(child) for child in token.tokens]
            return {
                "type": type(token).__name__,
                "value": token.normalized,
                "children": [c for c in children if c]
            }
        else:
            return {
                "type": token.ttype,
                "value": token.normalized
            }


class SqlQuery:
    def __init__(self, query: Optional[str] = None, normalizer: QueryNormalizer = QueryNormalizer(),
                 parser: QueryParser = QueryParser(), token_processor: TokenProcessor = TokenProcessor()):
        self.raw_query = query
        self.normalized_query = normalizer.normalize_query(query) if query else None
        self.tree: Optional[List[Token]] = None
        self.parser = parser
        self.token_processor = token_processor

    def set_query(self, query: str, normalize: bool = True):
        self.raw_query = query
        self.normalized_query = self.normalizer.normalize_query(query) if normalize else query
        self.tree = None

    @lru_cache(maxsize=None)
    def create_tree(self) -> List[Token]:
        query_to_parse = self.normalized_query if self.normalized_query else self.raw_query
        if not query_to_parse:
            logger.error("No query set")
            raise ValueError("No query set")

        try:
            parsed = self.parser.parse_query(query_to_parse)
            if not parsed:
                logger.error("Failed to parse the query")
                raise ValueError("Failed to parse the query")

            self.tree = parsed
            return self.tree
        except Exception as e:
            logger.exception(f"Error while parsing the query: {e}")
            raise

    def flatten_tree(self, tokens: Optional[List[Token]] = None) -> List[str]:
        if tokens is None:
            tokens = self.tree

        if tokens is None:
            logger.error("Parse tree is not set")
            raise ValueError("Parse tree is not set")

        flat_tokens = []
        for token in tokens:
            if isinstance(token, TokenList):
                flat_tokens.extend(self.flatten_tree(token.tokens))
            else:
                flat_tokens.append(token.normalized)
        return flat_tokens

    def tree_to_dict(self, tokens=None) -> Dict:
        if tokens is None:
            tokens = self.tree

        return {"type": "ROOT", "children": [self.token_processor.process(token) for token in tokens if self.token_processor.process(token)]}


In [66]:

class QueryProcessor:
    def __init__(self, query: str):
        self.query = query

    def normalize_query(self) -> str:
        return sqlparse.format(self.query, reindent=True, keyword_case='upper', strip_whitespace=True)

    def parse_query(self) -> List[Token]:
        try:
            parsed = sqlparse.parse(self.query)
            if not parsed:
                raise ValueError("Failed to parse the query")
            return parsed[0].tokens
        except Exception as e:
            raise ValueError(f"Error while parsing the query: {e}")

class TokenProcessor:
    def process(self, tokens: List[Token]) -> List[Dict]:
        result = []
        for token in tokens:
            if isinstance(token, TokenList):
                children = self.process(token.tokens)
                result.append({
                    "type": type(token).__name__,
                    "value": token.normalized,
                    "children": children
                })
            else:
                result.append({
                    "type": token.ttype,
                    "value": token.normalized
                })
        return result

class SqlQuery:
    def __init__(self):
        self.raw_query = None
        self.normalized_query = None
        self.tree: Optional[List[Token]] = None

    def set_query(self, query: str, normalize: bool = True):
        self.clear_cache()
        self.raw_query = query
        if normalize:
            processor = QueryProcessor(query)
            self.normalized_query = processor.normalize_query()
            self.tree = None
        else:
            processor = QueryProcessor(query)
            self.tree = processor.parse_query()
            self.normalized_query = None
            
    @lru_cache(maxsize=None)
    def create_tree(self) -> List[Token]:
        if not self.raw_query:
            raise ValueError("No query set")

        if self.tree is None:
            processor = QueryProcessor(self.raw_query)
            self.tree = processor.parse_query()

        return self.tree

    def clear_cache(self):
        """
        Clear the LRU cache used for create_tree.
        """
        self.create_tree.cache_clear()

    def flatten_tree(self, tokens: Optional[List[Token]] = None) -> List[str]:
        if tokens is None:
            tokens = self.create_tree()

        flat_tokens = []
        for token in tokens:
            if isinstance(token, TokenList):
                flat_tokens.extend(self.flatten_tree(token.tokens))
            else:
                flat_tokens.append(token.normalized)
        return flat_tokens

    def tree_to_dict(self, tokens=None) -> Dict:
        if tokens is None:
            tokens = self.create_tree()

        token_processor = TokenProcessor()
        return {"type": "ROOT", "children": token_processor.process(tokens)}


In [73]:

class QueryProcessor:
    def __init__(self, query: str):
        self.query = query

    def normalize_query(self) -> str:
        try:
            return sqlparse.format(self.query, reindent=True, keyword_case='upper', strip_whitespace=True)
        except Exception as e:
            raise ValueError(f"Error while normalizing the query: {e}")

    def parse_query(self) -> List[Token]:
        try:
            parsed = sqlparse.parse(self.query)
            if not parsed:
                raise ValueError("Failed to parse the query")
            return parsed[0].tokens
        except Exception as e:
            raise ValueError(f"Error while parsing the query: {e}")

class TokenProcessor:
    def process(self, tokens: List[Token]) -> List[Dict]:
        result = []
        for token in tokens:
            if isinstance(token, TokenList):
                children = self.process(token.tokens)
                result.append({
                    "type": type(token).__name__,
                    "value": token.normalized,
                    "children": children
                })
            else:
                result.append({
                    "type": token.ttype,
                    "value": token.normalized
                })
        return result

class SqlQuery:
    def __init__(self):
        self.raw_query = None
        self.normalized_query = None
        self.tree: Optional[List[Token]] = None

    def set_query(self, query: str, normalize: bool = True):
        self.clear_cache()  # Clear the cache before processing a new query.
        if not query:
            raise ValueError("Query cannot be empty")
        self.raw_query = query
        if normalize:
            processor = QueryProcessor(query)
            self.normalized_query = processor.normalize_query()
            self.tree = None
        else:
            processor = QueryProcessor(query)
            self.tree = processor.parse_query()
            self.normalized_query = None

    @lru_cache(maxsize=None)
    def create_tree(self) -> List[Token]:
        if not self.raw_query:
            raise ValueError("No query set")

        if self.tree is None:
            processor = QueryProcessor(self.raw_query)
            self.tree = processor.parse_query()

        return self.tree

    def clear_cache(self):
        """
        Clear the LRU cache used for create_tree.
        """
        self.create_tree.cache_clear()

    def flatten_tree(self, tokens: Optional[List[Token]] = None) -> List[str]:
        if tokens is None:
            tokens = self.create_tree()

        flat_tokens = []
        for token in tokens:
            if isinstance(token, TokenList):
                flat_tokens.extend(self.flatten_tree(token.tokens))
            else:
                flat_tokens.append(token.normalized)
        return flat_tokens

    def tree_to_dict(self, tokens=None) -> Dict:
        if tokens is None:
            tokens = self.create_tree()

        token_processor = TokenProcessor()
        return {"type": "ROOT", "children": token_processor.process(tokens)}


In [101]:


class QueryProcessor:
    def __init__(self, query: str):
        self.query = query

    def normalize_query(self) -> str:
        try:
            return sqlparse.format(self.query, reindent=True, keyword_case='upper', strip_whitespace=True)
        except Exception as e:
            raise ValueError(f"Error while normalizing the query: {e}")

    def parse_query(self) -> List[Token]:
        try:
            parsed = sqlparse.parse(self.query)
            if not parsed:
                raise ValueError("Failed to parse the query")
            return parsed[0].tokens
        except Exception as e:
            raise ValueError(f"Error while parsing the query: {e}")

class TokenProcessor:
    def process(self, tokens: List[Token]) -> List[Dict]:
        result = []
        for token in tokens:
            if isinstance(token, TokenList):
                children = self.process(token.tokens)
                result.append({
                    "type": type(token).__name__,
                    "value": token.normalized,
                    "is_group": True,  # Add the is_group tag.
                    "children": children
                })
            else:
                result.append({
                    "type": token.ttype,
                    "value": token.normalized,
                    "is_group": False  # Add the is_group tag.
                })
        return result

class SqlQuery:
    def __init__(self):
        self.raw_query = None
        self.normalized_query = None
        self.tree: Optional[List[Token]] = None

    def set_query(self, query: str, normalize: bool = True):
        self.clear_cache()  # Clear the cache before processing a new query.
        if not query:
            raise ValueError("Query cannot be empty")
        self.raw_query = query
        if normalize:
            processor = QueryProcessor(query)
            self.normalized_query = processor.normalize_query()
            self.tree = None
        else:
            processor = QueryProcessor(query)
            self.tree = processor.parse_query()
            self.normalized_query = None

    @lru_cache(maxsize=None)
    def create_tree(self) -> List[Token]:
        if not self.raw_query:
            raise ValueError("No query set")

        if self.tree is None:
            processor = QueryProcessor(self.raw_query)
            self.tree = processor.parse_query()

        return self.tree

    def clear_cache(self):
        """
        Clear the LRU cache used for create_tree.
        """
        self.create_tree.cache_clear()

    def flatten_tree(self, tokens: Optional[List[Token]] = None) -> List[str]:
        if tokens is None:
            tokens = self.create_tree()

        flat_tokens = []
        for token in tokens:
            if isinstance(token, TokenList):
                flat_tokens.extend(self.flatten_tree(token.tokens))
            else:
                flat_tokens.append(token.normalized)
        return flat_tokens

    def tree_to_dict(self, tokens=None) -> Dict:
        if tokens is None:
            tokens = self.create_tree()

        token_processor = TokenProcessor()
        return {"type": "ROOT", "children": token_processor.process(tokens)}


In [102]:
query = SqlQuery()

In [103]:
query.set_query(sql_queries[-3])

In [105]:
query.tree_to_dict()

{'type': 'ROOT',
 'children': [{'type': Token.Text.Whitespace.Newline,
   'value': '\n',
   'is_group': False},
  {'type': Token.Text.Whitespace, 'value': ' ', 'is_group': False},
  {'type': Token.Text.Whitespace, 'value': ' ', 'is_group': False},
  {'type': Token.Text.Whitespace, 'value': ' ', 'is_group': False},
  {'type': Token.Text.Whitespace, 'value': ' ', 'is_group': False},
  {'type': Token.Keyword.DML, 'value': 'SELECT', 'is_group': False},
  {'type': Token.Text.Whitespace, 'value': ' ', 'is_group': False},
  {'type': Token.Wildcard, 'value': '*', 'is_group': False},
  {'type': Token.Text.Whitespace, 'value': ' ', 'is_group': False},
  {'type': Token.Keyword, 'value': 'FROM', 'is_group': False},
  {'type': Token.Text.Whitespace, 'value': ' ', 'is_group': False},
  {'type': 'Identifier',
   'value': '(\n        SELECT col1 FROM table1\n        UNION\n        SELECT col2 FROM table2\n    ) AS subquery',
   'is_group': True,
   'children': [{'type': 'Parenthesis',
     'value': '(

In [95]:

class QueryProcessor:
    def __init__(self, query: str):
        self.query = query

    def normalize_query(self) -> str:
        try:
            return sqlparse.format(self.query, reindent=True, keyword_case='upper', strip_whitespace=True)
        except Exception as e:
            raise ValueError(f"Error while normalizing the query: {e}")

    def parse_query(self) -> List[Token]:
        try:
            parsed = sqlparse.parse(self.query)
            if not parsed:
                raise ValueError("Failed to parse the query")
            return parsed[0].tokens
        except Exception as e:
            raise ValueError(f"Error while parsing the query: {e}")

class TokenProcessor:
    def process(self, tokens: List[Token]) -> List[Dict]:
        result = []
        for token in tokens:
            if isinstance(token, TokenList):
                children = self.process(token.tokens)
                result.append({
                    "type": type(token).__name__,
                    "value": token,
                    "is_group": True,  # Add the is_group tag.
                    "children": children
                })
            else:
                result.append({
                    "type": token.ttype,
                    "value": token,
                    "is_group": False  # Add the is_group tag.
                })
        return result



class SqlQuery:
    def __init__(self, query: Optional[str] = None):
        """
        Initialize the SqlQuery class with an optional SQL query.

        :param query: Raw SQL query string.
        """
        self._raw_query = query
        self._normalized_query = self._normalize_query(query) if query else None
        self._tree: Optional[List[Token]] = None
        self._parsed_query: Optional[Token] = None

    @property
    def raw_query(self):
        return self._raw_query

    @raw_query.setter
    def raw_query(self, query):
        self.clear_cache()
        self._raw_query = query

    @property
    def normalized_query(self):
        if not self._normalized_query:
            self._normalized_query = self._normalize_query(self._raw_query) if self._raw_query else ''
        return self._normalized_query

    def clear_cache(self):
        """
        Clear the LRU caches used for tree_dict, flatten_tree, and flatten_dict_tree.
        """
        self.tree_dict.cache_clear()
        self.flatten_tree.cache_clear()
        self.flatten_dict_tree.cache_clear()

    def _normalize_query(self, query: str) -> str:
        """
        Normalizes the SQL query by formatting.

        :param query: Raw SQL query string.
        :return: Normalized SQL query string.
        """
        return sqlparse.format(query, reindent=True, keyword_case='upper', strip_whitespace=True) if query else ''

    def set_query(self, query: str, normalize: bool = True):
        """
        Sets a new SQL query and optionally normalizes it.

        :param query: Raw SQL query string.
        :param normalize: Boolean flag to normalize the query.
        """
        self._raw_query = query
        if normalize:
            self._normalized_query = self._normalize_query(query)
        else:
            self._normalized_query = query

    @lru_cache(maxsize=None)
    def create_tree(self) -> List[Token]:
        """
        Parses the SQL query (normalized, if available) and creates a parse tree.

        :return: List of Tokens representing the parse tree.
        """
        query_to_parse = self.normalized_query if self.normalized_query else self.raw_query
        if not query_to_parse:
            raise ValueError("No query set")

        try:
            parsed = sqlparse.parse(query_to_parse)
            if not parsed:
                raise ValueError("Failed to parse the query")

            self._parsed_query = parsed[0]
            self._tree = self._parsed_query.tokens
            return self._tree
        except Exception as e:
            raise ValueError(f"Error while parsing the query: {e}")

    @property
    def tree(self):
        if not self._tree:
            self.create_tree()
        return self._tree

    def tree_to_dict(self, tokens=None) -> Dict:
        """
        Converts a list of SQL tokens to a nested dictionary representation.

        :param tokens: List of SQL tokens.
        :return: Nested dictionary representing the SQL tokens.
        """
        if tokens is None:
            tokens = self.tree

        result = {
            "type": "ROOT",
            "children": [self._process_token(token) for token in tokens if self._process_token(token)]
        }
        return result

    def flatten_tree(self, tokens: Optional[List[Token]] = None) -> List[str]:
        """
        Flattens the parse tree to a list of token strings.

        :param tokens: Optional list of tokens to flatten.
        :return: List of strings representing flattened tokens.
        """
        if tokens is None:
            tokens = self.tree

        flat_tokens = []
        for token in tokens:
            if isinstance(token, TokenList):
                flat_tokens.extend(self.flatten_tree(token.tokens))
            else:
                flat_tokens.append(token.normalized)
        return flat_tokens

    def flatten_dict_tree(self, tokens=None) -> List[Dict]:
        """
        Flattens the parse tree to a list of dictionaries.

        :param tokens: List of SQL tokens.
        :return: List of dictionaries representing flattened tokens.
        """
        if tokens is None:
            tokens = self.tree

        flat_tokens = []
        for token in tokens:
            flat_tokens.append(self._process_token(token))
        return flat_tokens

    def _process_token(self, token) -> Union[Dict, str, None]:
        """
        Process a single SQL token into a dictionary representation.

        :param token: The SQL token to process.
        :return: A dictionary representation of the token, or None for whitespace.
        """
        if isinstance(token, TokenList):
            children = [self._process_token(child) for child in token.tokens]
            return {
                "type": type(token).__name__,
                "value": token.normalized,
                "children": [c for c in children if c]
            }
        elif not isinstance(token, Whitespace):
            return {
                "type": token.ttype,
                "value": token.normalized
            }
        return None

    @property
    @lru_cache(maxsize=None)
    def tree_dict(self):
        return self.tree_to_dict()

    @property
    @lru_cache(maxsize=None)
    def flatten_tree(self):
        return self.flatten_tree()

    @property
    @lru_cache(maxsize=None)
    def flatten_dict_tree(self):
        return self.flatten_dict_tree()



In [98]:
query = SqlQuery(sql_queries[-3])

In [97]:
query.raw_query =  sql_queries[-3]

ValueError: No query set

In [99]:
query.tree_to_dict()

TypeError: isinstance() arg 2 must be a type or tuple of types

In [81]:
query.flatten_tree()

['\n',
 ' ',
 ' ',
 ' ',
 ' ',
 'SELECT',
 ' ',
 '*',
 ' ',
 'FROM',
 ' ',
 '(',
 '\n',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 'SELECT',
 ' ',
 'col1',
 ' ',
 'FROM',
 ' ',
 'table1',
 '\n',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 'UNION',
 '\n',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 'SELECT',
 ' ',
 'col2',
 ' ',
 'FROM',
 ' ',
 'table2',
 '\n',
 ' ',
 ' ',
 ' ',
 ' ',
 ')',
 ' ',
 'AS',
 ' ',
 'subquery',
 '\n',
 ' ',
 ' ',
 ' ',
 ' ']

In [57]:
query.tree_to_dict()

{'type': 'ROOT',
 'children': [{'type': Token.Text.Whitespace.Newline, 'value': '\n'},
  {'type': Token.Keyword.DML, 'value': 'SELECT'},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': Token.Wildcard, 'value': '*'},
  {'type': Token.Text.Whitespace.Newline, 'value': '\n'},
  {'type': Token.Keyword, 'value': 'FROM'},
  {'type': Token.Text.Whitespace.Newline, 'value': '\n'},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': 'Identifier',
   'value': '(SELECT col1\n   FROM table1\n   UNION SELECT col2\n   FROM table2) AS subquery',
   'children': [{'type': 'Parenthesis',
     'value': '(SELECT col1\n   FROM table1\n   UNION SELECT col2\n   FROM table2)',
     'children': [{'type': Token.Punctuation, 'value': '('},
      {'type': Token.Keyword.DML, 'value': 'SELECT'},
      {'type': Token.Text.Whitespace, 'value': ' '},
      {'type': 'Identifier',
       'value': 'col1',
       'children': [{'type': Token.Name, 'value'

In [96]:
query.tree_to_dict()

{'type': 'ROOT',
 'children': [{'type': Token.Keyword.DML, 'value': 'UPDATE'},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': Token.Keyword, 'value': 'TABLE'},
  {'type': Token.Text.Whitespace.Newline, 'value': '\n'},
  {'type': Token.Keyword, 'value': 'SET'},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': 'Comparison',
   'value': 'column1 = 10',
   'children': [{'type': 'Identifier',
     'value': 'column1',
     'children': [{'type': Token.Name, 'value': 'column1'}]},
    {'type': Token.Operator.Comparison, 'value': '='},
    {'type': Token.Literal.Number.Integer, 'value': '10'}]}]}

In [72]:
for token in query.tree:
    print(token)
    try:
        print(f"has_ancestor: {token.has_ancestor}")
    except: pass
    try:
        print(f"get_sublists: {token.get_sublists}")
    except:
        pass
    print(f"ttype: {token.ttype}")
   
    try:
        print(f"tokens: {token.tokens}")
    except: pass
    # print(f"token_first: {token.token_first}")
    # print(f"token_next: {token.token_next}")
    # print(f"token_prev: {token.token_prev}")
    print(f"is_group: {token.is_group}")
    print(type(token))
    print()

UPDATE
has_ancestor: <bound method Token.has_ancestor of <DML 'UPDATE' at 0x7F7E8D9660A0>>
ttype: Token.Keyword.DML
is_group: False
<class 'sqlparse.sql.Token'>

 
has_ancestor: <bound method Token.has_ancestor of <Whitespace ' ' at 0x7F7E8D9661C0>>
ttype: Token.Text.Whitespace
is_group: False
<class 'sqlparse.sql.Token'>

TABLE
has_ancestor: <bound method Token.has_ancestor of <Keyword 'TABLE' at 0x7F7E8D957B80>>
ttype: Token.Keyword
is_group: False
<class 'sqlparse.sql.Token'>



has_ancestor: <bound method Token.has_ancestor of <Newline ' ' at 0x7F7E8D957460>>
ttype: Token.Text.Whitespace.Newline
is_group: False
<class 'sqlparse.sql.Token'>

SET
has_ancestor: <bound method Token.has_ancestor of <Keyword 'SET' at 0x7F7E8D957E80>>
ttype: Token.Keyword
is_group: False
<class 'sqlparse.sql.Token'>

 
has_ancestor: <bound method Token.has_ancestor of <Whitespace ' ' at 0x7F7E8D957EE0>>
ttype: Token.Text.Whitespace
is_group: False
<class 'sqlparse.sql.Token'>

column1 = 10
has_ancestor: <

In [61]:
query.flatten_tree()

['UPDATE', ' ', 'TABLE', '\n', 'SET', ' ', 'column1', ' ', '=', ' ', '10']

In [62]:
"UPDATE table SET column1 = 10"

'UPDATE table SET column1 = 10'

In [63]:
query.tree_to_dict() 

TypeError: isinstance() arg 2 must be a type or tuple of types

In [24]:
query = "SELECT * FROM users"
sql_query = SqlQuery(query)
sql_query.create_tree()

[<DML 'SELECT' at 0x7F7E8C11FC40>,
 <Whitespace ' ' at 0x7F7E8C11FE80>,
 <Wildcard '*' at 0x7F7E8C11FEE0>,
 <Newline ' ' at 0x7F7E8C11FF40>,
 <Keyword 'FROM' at 0x7F7E8C11FFA0>,
 <Whitespace ' ' at 0x7F7E8C11E040>,
 <Identifier 'users' at 0x7F7E8C120430>]

In [31]:
sql_query.tree_to_dict()

{'type': 'ROOT',
 'children': [{'type': Token.Keyword.DML, 'value': 'SELECT'},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': Token.Wildcard, 'value': '*'},
  {'type': Token.Text.Whitespace.Newline, 'value': '\n'},
  {'type': Token.Keyword, 'value': 'FROM'},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': 'TokenList',
   'value': 'users',
   'children': [{'type': Token.Name, 'value': 'users'}]}]}

In [36]:
sql_query.tree[-1]

<Identifier 'users' at 0x7F7E8C120430>

In [18]:
def test_tree_to_dict_select():
    query = "SELECT * FROM users"
    expected_dict = {
        "type": "ROOT",
        "children": [
            {"type": sqlparse.tokens.DML, "value": "SELECT"},
            {"type": sqlparse.tokens.Wildcard, "value": "*"},
            {"type": sqlparse.tokens.Keyword, "value": "FROM"},
            {"type": None, "value": "users"}
        ]
    }
    sql_query = SqlQuery(query)
    sql_query.create_tree()
    print(sql_query.tree_to_dict())
    print(expected_dict)
    # assert sql_query.tree_to_dict() == expected_dict, "Failed tree to dict for SELECT query"


In [19]:
test_tree_to_dict_select()

{'type': 'ROOT', 'children': [{'type': Token.Keyword.DML, 'value': 'SELECT'}, {'type': Token.Text.Whitespace, 'value': ' '}, {'type': Token.Wildcard, 'value': '*'}, {'type': Token.Text.Whitespace.Newline, 'value': '\n'}, {'type': Token.Keyword, 'value': 'FROM'}, {'type': Token.Text.Whitespace, 'value': ' '}, {'type': 'TokenList', 'value': 'users', 'children': [{'type': Token.Name, 'value': 'users'}]}]}
{'type': 'ROOT', 'children': [{'type': Token.Keyword.DML, 'value': 'SELECT'}, {'type': Token.Wildcard, 'value': '*'}, {'type': Token.Keyword, 'value': 'FROM'}, {'type': None, 'value': 'users'}]}


In [22]:
def test_tree_to_dict_select():
    query = "SELECT * FROM users"
    expected_dict = {
        "type": "ROOT",
        "children": [
            {"type": sqlparse.tokens.DML, "value": "SELECT", "children": []},
            {"type": sqlparse.tokens.Whitespace, "value": " ", "children": []},
            {"type": sqlparse.tokens.Wildcard, "value": "*", "children": []},
            {"type": sqlparse.tokens.Whitespace, "value": " ", "children": []},
            {"type": sqlparse.tokens.Keyword, "value": "FROM", "children": []},
            {"type": sqlparse.tokens.Whitespace, "value": " ", "children": []},
            {"type": None, "value": "users", "children": []}
        ]
    }
    sql_query = SqlQuery(query)
    sql_query.create_tree()
    print(sql_query.tree_to_dict())
    # print(expected_dict)
    # assert sql_query.tree_to_dict() == expected_dict, "Failed tree to dict for SELECT query"


In [23]:
test_tree_to_dict_select()

{'type': 'ROOT', 'children': [{'type': Token.Keyword.DML, 'value': 'SELECT'}, {'type': Token.Text.Whitespace, 'value': ' '}, {'type': Token.Wildcard, 'value': '*'}, {'type': Token.Text.Whitespace.Newline, 'value': '\n'}, {'type': Token.Keyword, 'value': 'FROM'}, {'type': Token.Text.Whitespace, 'value': ' '}, {'type': 'TokenList', 'value': 'users', 'children': [{'type': Token.Name, 'value': 'users'}]}]}


In [106]:
# test_cases.py

test_cases = [
    {
        "query": "UPDATE table SET column1 = 10",
        "expected": {
            "type": "ROOT",
            "children": [
                {"type": "DML", "value": "UPDATE", "is_group": False},
                {"type": "Identifier", "value": "table", "is_group": False},
                {"type": "Keyword", "value": "SET", "is_group": False},
                {"type": "Identifier", "value": "column1", "is_group": False},
                {"type": "Comparison", "value": "= 10", "is_group": False}
            ]
        }
    },
    {
        "query": "SELECT column FROM table",
        "expected": {
            "type": "ROOT",
            "children": [
                {"type": "DML", "value": "SELECT", "is_group": False},
                {"type": "Identifier", "value": "column", "is_group": False},
                {"type": "Keyword", "value": "FROM", "is_group": False},
                {"type": "Identifier", "value": "table", "is_group": False}
            ]
        }
    },
    {
        "query": "SELECT * FROM users",
        "expected": {
            "type": "ROOT",
            "children": [
                {"type": "DML", "value": "SELECT", "is_group": False},
                {"type": "Wildcard", "value": "*", "is_group": False},
                {"type": "Keyword", "value": "FROM", "is_group": False},
                {"type": "Identifier", "value": "users", "is_group": False}
            ]
        }
    },
    {
        "query": "SELECT name, age FROM People WHERE age > 30",
        "expected": {
            "type": "ROOT",
            "children": [
                {"type": "DML", "value": "SELECT", "is_group": False},
                {"type": "IdentifierList", "value": "name, age", "is_group": True, "children": [
                    {"type": "Identifier", "value": "name", "is_group": False},
                    {"type": "Identifier", "value": "age", "is_group": False}
                ]},
                {"type": "Keyword", "value": "FROM", "is_group": False},
                {"type": "Identifier", "value": "People", "is_group": False},
                {"type": "Where", "value": "WHERE age > 30", "is_group": True, "children": [
                    {"type": "Comparison", "value": "age > 30", "is_group": False}
                ]}
            ]
        }
    },

{
    "query": """
        SELECT u.id, 
               (SELECT COUNT(*) FROM orders WHERE orders.user_id = u.id) as order_count,
               (SELECT AVG(amount) FROM payments WHERE payments.user_id = u.id) as average_payment
        FROM users u 
        WHERE u.registration_date BETWEEN '2020-01-01' AND '2020-12-31'
        AND (u.status = 'active' OR u.id IN (SELECT user_id FROM vip_users));
    """,
    "expected": {
        "type": "ROOT",
        "children": [
            {"type": "DML", "value": "SELECT", "is_group": False},
            {"type": "Identifier", "value": "u.id", "is_group": False},
            {"type": "Subselect", "value": "(SELECT COUNT(*) FROM orders WHERE orders.user_id = u.id)", "is_group": True, "children": [
                {"type": "DML", "value": "SELECT", "is_group": False},
                {"type": "Function", "value": "COUNT(*)", "is_group": False},
                {"type": "Keyword", "value": "FROM", "is_group": False},
                {"type": "Identifier", "value": "orders", "is_group": False},
                {"type": "Where", "value": "WHERE orders.user_id = u.id", "is_group": True, "children": [
                    {"type": "Comparison", "value": "orders.user_id = u.id", "is_group": False}
                ]}
            ]},
            {"type": "Subselect", "value": "(SELECT AVG(amount) FROM payments WHERE payments.user_id = u.id)", "is_group": True, "children": [
                {"type": "DML", "value": "SELECT", "is_group": False},
                {"type": "Function", "value": "AVG(amount)", "is_group": False},
                {"type": "Keyword", "value": "FROM", "is_group": False},
                {"type": "Identifier", "value": "payments", "is_group": False},
                {"type": "Where", "value": "WHERE payments.user_id = u.id", "is_group": True, "children": [
                    {"type": "Comparison", "value": "payments.user_id = u.id", "is_group": False}
                ]}
            ]},
            {"type": "Keyword", "value": "FROM", "is_group": False},
            {"type": "Identifier", "value": "users u", "is_group": False},
            {"type": "Where", "value": "WHERE u.registration_date BETWEEN '2020-01-01' AND '2020-12-31' AND (u.status = 'active' OR u.id IN (SELECT user_id FROM vip_users))", "is_group": True, "children": [
                {"type": "Comparison", "value": "u.registration_date BETWEEN '2020-01-01' AND '2020-12-31'", "is_group": False},
                {"type": "Boolean", "value": "AND", "is_group": False},
                {"type": "Parenthesis", "value": "(u.status = 'active' OR u.id IN (SELECT user_id FROM vip_users))", "is_group": True, "children": [
                    {"type": "Comparison", "value": "u.status = 'active'", "is_group": False},
                    {"type": "Boolean", "value": "OR", "is_group": False},
                    {"type": "Subselect", "value": "(SELECT user_id FROM vip_users)", "is_group": True, "children": [
                        {"type": "DML", "value": "SELECT", "is_group": False},
                        {"type": "Identifier", "value": "user_id", "is_group": False},
                        {"type": "Keyword", "value": "FROM", "is_group": False},
                        {"type": "Identifier", "value": "vip_users", "is_group": False}
                    ]}
                ]}
            ]}
        ]
    }
},



    {
       "query": """
        WITH cte AS (
            SELECT col1 FROM table1
            UNION
            SELECT col2 FROM table2
        )
        SELECT * FROM cte
    """,
    "expected": {
        "type": "ROOT",
        "children": [
            {"type": "CTE", "value": "WITH", "is_group": True, "children": [
                {"type": "Identifier", "value": "cte", "is_group": False},
                {"type": "Keyword", "value": "AS", "is_group": False},
                {"type": "Subselect", "value": "(...)", "is_group": True, "children": [
                    {"type": "DML", "value": "SELECT", "is_group": False},
                    {"type": "Identifier", "value": "col1", "is_group": False},
                    {"type": "Keyword", "value": "FROM", "is_group": False},
                    {"type": "Identifier", "value": "table1", "is_group": False},
                    {"type": "SetOperation", "value": "UNION", "is_group": False},
                    {"type": "DML", "value": "SELECT", "is_group": False},
                    {"type": "Identifier", "value": "col2", "is_group": False},
                    {"type": "Keyword", "value": "FROM", "is_group": False},
                    {"type": "Identifier", "value": "table2", "is_group": False}
                ]}
            ]},
            {"type": "DML", "value": "SELECT", "is_group": False},
            {"type": "Wildcard", "value": "*", "is_group": False},
            {"type": "Keyword", "value": "FROM", "is_group": False},
            {"type": "Identifier", "value": "cte", "is_group": False}
        ]
    }
    },
    # Additional test cases can be added here
]


In [108]:
for i in test_cases:
    print(i['query'])

UPDATE table SET column1 = 10
SELECT column FROM table
SELECT * FROM users
SELECT name, age FROM People WHERE age > 30

        SELECT u.id, 
               (SELECT COUNT(*) FROM orders WHERE orders.user_id = u.id) as order_count,
               (SELECT AVG(amount) FROM payments WHERE payments.user_id = u.id) as average_payment
        FROM users u 
        WHERE u.registration_date BETWEEN '2020-01-01' AND '2020-12-31'
        AND (u.status = 'active' OR u.id IN (SELECT user_id FROM vip_users));
    

        WITH cte AS (
            SELECT col1 FROM table1
            UNION
            SELECT col2 FROM table2
        )
        SELECT * FROM cte
    


UPDATE table SET column1 = 10
SELECT column FROM table
SELECT * FROM users
SELECT name, age FROM People WHERE age > 30

SELECT u.id, (SELECT COUNT(*) FROM orders WHERE orders.user_id = u.id) as order_count, (SELECT AVG(amount) FROM payments WHERE payments.user_id = u.id) as average_payment
FROM users u 
WHERE u.registration_date BETWEEN '2020-01-01' AND '2020-12-31'
AND (u.status = 'active' OR u.id IN (SELECT user_id FROM vip_users));

WITH cte AS (
    SELECT col1 FROM table1
    UNION
    SELECT col2 FROM table2
)
SELECT * FROM cte