In [1]:
from query_tree import SQLToken, SQLNode, SQLQueryTree

In [56]:
import sqlparse
from typing import Union, Dict

In [3]:
sql_query_3 = """
    SELECT u.country, COUNT(*) as user_count
    FROM users u
    WHERE u.age > 18
    GROUP BY u.country
    HAVING COUNT(*) > 100
    ORDER BY user_count DESC;
"""

query_tree_3 = SQLQueryTree(sql_query_3)
query_tree_3.get_depth()


Node: Token.Punctuation: ;

Children: []


2

In [5]:
sql_query = """
    SELECT u.id
    FROM users
    ;
"""

query_tree = SQLQueryTree(sql_query)
# print(query_tree)
query_tree.get_depth()


Node: Token.Punctuation: ;

Children: []


2

In [4]:
sql_query_2 = """
    SELECT u.id, 
           (SELECT COUNT(*) FROM orders WHERE orders.user_id = u.id) as order_count,
           (SELECT AVG(amount) FROM payments WHERE payments.user_id = u.id) as average_payment
    FROM users u 
    WHERE u.registration_date BETWEEN '2020-01-01' AND '2020-12-31'
    AND (u.status = 'active' OR u.id IN (SELECT user_id FROM vip_users));
"""

query_tree_2 = SQLQueryTree(sql_query_2)
query_tree_2.root


Node: None: WHERE u.registration_date BETWEEN '2020-01-01' AND '2020-12-31'
    AND (u.status = 'active' OR u.id IN (SELECT user_id FROM vip_users));

Children: []


<SQLNode token=None, children=23>

In [7]:
sql_query = """
    SELECT u.id
    FROM users (select * from (select * from table)) 
    ;
"""

query_tree = SQLQueryTree(sql_query)
print(query_tree)
query_tree.get_depth()


  Token.Text.Whitespace.Newline: 

  Token.Text.Whitespace:  
  Token.Text.Whitespace:  
  Token.Text.Whitespace:  
  Token.Text.Whitespace:  
  Token.Keyword.DML: SELECT
  Token.Text.Whitespace:  
  None: u.id
  Token.Text.Whitespace.Newline: 

  Token.Text.Whitespace:  
  Token.Text.Whitespace:  
  Token.Text.Whitespace:  
  Token.Text.Whitespace:  
  Token.Keyword: FROM
  Token.Text.Whitespace:  
  None: users (select * from (select * from table))
  Token.Text.Whitespace:  
  Token.Text.Whitespace.Newline: 

  Token.Text.Whitespace:  
  Token.Text.Whitespace:  
  Token.Text.Whitespace:  
  Token.Text.Whitespace:  
  Token.Punctuation: ;



2

In [6]:
import sqlparse
import logging
from typing import List, Optional, Iterator

class SQLToken:
    """
    Represents a basic SQL token.

    Attributes
    ----------
    token_type : sqlparse.tokens.TokenType
        The type of the SQL token.
    value : str
        The string value of the token.
    """
    def __init__(self, token_type: sqlparse.tokens._TokenType, value: str):
        self.token_type = token_type
        self.value = value

In [27]:
class SQLNode:
    """
    Represents a node in the SQL query tree.

    Attributes
    ----------
    token : Optional[SQLToken]
        The SQL token associated with this node. Can be None for structural nodes.
    children : List[SQLNode])
        The list of child nodes.
    """
    def __init__(self, token: Optional[SQLToken] = None):
        self.token = token
        self.children: List[SQLNode] = []

    def add_child(self, child: 'SQLNode'):
        """Add a child node to this node."""
        self.children.append(child)
        
    def to_dict(self) -> dict:
        """
        Convert the SQLNode and its subtree into a nested dictionary.

        Returns:
            dict: A dictionary representation of the SQLNode.
        """
        node_dict = {
            "token_type": str(self.token.token_type) if self.token else None,
            "value": self.token.value if self.token else None,
            "children": [child.to_dict() for child in self.children]
        }
        return node_dict

   

In [28]:
sql_query = "SELECT column FROM table"
parsed_query = sqlparse.parse(sql_query)[0]

In [32]:
sql_query = """SELECT table.column1, table2.column2 FROM table JOIN (SELECT column2 FROM newtable) table2 ON table.column3 = table2.column3"""
parsed_query = sqlparse.parse(sql_query)[0]

In [34]:
list(parsed_query)

[<DML 'SELECT' at 0x7FC0DB22D5E0>,
 <Whitespace ' ' at 0x7FC0DB273460>,
 <IdentifierList 'table....' at 0x7FC0DB26A900>,
 <Whitespace ' ' at 0x7FC0DB273B80>,
 <Keyword 'FROM' at 0x7FC0DB273BE0>,
 <Whitespace ' ' at 0x7FC0DB273C40>,
 <Keyword 'table' at 0x7FC0DB273CA0>,
 <Whitespace ' ' at 0x7FC0DB273D00>,
 <Keyword 'JOIN' at 0x7FC0DB273D60>,
 <Whitespace ' ' at 0x7FC0DB273DC0>,
 <Identifier '(SELEC...' at 0x7FC0DB26A890>,
 <Whitespace ' ' at 0x7FC0DB270280>,
 <Keyword 'ON' at 0x7FC0DB2702E0>,
 <Whitespace ' ' at 0x7FC0DB270340>,
 <Comparison 'table....' at 0x7FC0DB26A740>]

In [29]:
def _build_tree(parsed_query):
    root = SQLNode()
    for token in parsed_query.tokens:
        if isinstance(token, sqlparse.sql.Token):
            node = SQLNode(SQLToken(token.ttype, token.value))
            root.add_child(node)
        elif isinstance(token, (sqlparse.sql.Parenthesis, sqlparse.sql.Statement)):
            node = SQLNode()
            root.add_child(node)
            node.children.append(_build_tree(token))
    return root

In [35]:
_build_tree(parsed_query).to_dict()

{'token_type': None,
 'value': None,
 'children': [{'token_type': 'Token.Keyword.DML',
   'value': 'SELECT',
   'children': []},
  {'token_type': 'Token.Text.Whitespace', 'value': ' ', 'children': []},
  {'token_type': 'None',
   'value': 'table.column1, table2.column2',
   'children': []},
  {'token_type': 'Token.Text.Whitespace', 'value': ' ', 'children': []},
  {'token_type': 'Token.Keyword', 'value': 'FROM', 'children': []},
  {'token_type': 'Token.Text.Whitespace', 'value': ' ', 'children': []},
  {'token_type': 'Token.Keyword', 'value': 'table', 'children': []},
  {'token_type': 'Token.Text.Whitespace', 'value': ' ', 'children': []},
  {'token_type': 'Token.Keyword', 'value': 'JOIN', 'children': []},
  {'token_type': 'Token.Text.Whitespace', 'value': ' ', 'children': []},
  {'token_type': 'None',
   'value': '(SELECT column2 FROM newtable) table2',
   'children': []},
  {'token_type': 'Token.Text.Whitespace', 'value': ' ', 'children': []},
  {'token_type': 'Token.Keyword', 'value

In [None]:






class SQLQueryTree:
    """
    Represents the entire SQL query tree.

    Attributes
    ----------
    sql_query : str
        The SQL query string used to build the tree.
    root : SQLNode
        The root node of the SQL query tree.
    """
    def __init__(self, sql_query: str):
        self.sql_query = sql_query
        self.root: SQLNode = self._build_tree(sqlparse.parse(sql_query)[0])

    def _build_tree(self, parsed_query: sqlparse.sql.Statement) -> SQLNode:
        """
        Build the SQL query tree from a parsed SQL statement.

        Arguments
        ---------
        parsed_query : sqlparse.sql.Statement
            The parsed SQL statement.

        Returns
        -------
        SQLNode
            The root node of the SQL query tree.

        Raises
            ValueError: If the SQL query is invalid or parsing fails.
        """
        if not parsed_query:
            raise ValueError("Invalid SQL query provided")

        root = SQLNode()
        for token in parsed_query.tokens:
            if isinstance(token, sqlparse.sql.Token):
                node = SQLNode(SQLToken(token.ttype, token.value))
                root.add_child(node)
            elif isinstance(token, (sqlparse.sql.Parenthesis, sqlparse.sql.Statement)):
                node = SQLNode()
                root.add_child(node)
                node.children.append(self._build_tree(token))
        print("Node:", node)
        print("Children:", node.children)
        return root
    
    def get_depth(self) -> int:
        """
        Calculate the maximum depth of the SQL query tree.

        Returns:
            int: The maximum depth of the tree.
        """
        return self._get_depth_recursive(self.root)

    def _get_depth_recursive(self, node: SQLNode) -> int:
        """
        Helper method to recursively calculate the depth of the tree.

        Args:
            node (SQLNode): The current node being processed.
            current_depth (int): The depth of the current node.

        Returns:
            int: The depth of the tree rooted at the current node.
        """
        if not node.children:
            return 1
        return  1 + max(self._get_depth_recursive(child) for child in node.children)


In [57]:
import sqlparse
from functools import lru_cache
from sqlparse.tokens import DML, DDL
from sqlparse.sql import Token, TokenList
from typing import List


class SqlQuery:
    def __init__(self):
        self.tree: List[Token] = []
        self.parsed_query = None

    @lru_cache(maxsize=None)
    def create_tree(self, query: str) -> List[Token]:
        try:
            if not query or not isinstance(query, str):
                raise ValueError("Invalid query provided")
                
            self.parsed_query = sqlparse.parse(query)
            if not self.parsed_query:
                raise ValueError("Failed to parse the query")

            self.tree = self.parsed_query[0].tokens
            return self.tree
        except ValueError as ve:
            raise ve
        except Exception as e:
            raise ValueError(f"SQL parse error: {e}")

    def flatten_tree(self, tokens=None) -> List[str]:
        if tokens is None:
            tokens = self.tree

        flat_tokens = []
        for token in tokens:
            if isinstance(token, TokenList):
                flat_tokens.extend(self.flatten_tree(token.tokens))
            else:
                flat_tokens.append(token.normalized)
        return flat_tokens

    def recompose_query(self, tokens=None) -> str:
        if tokens is None:
            tokens = self.tree

        query = ''
        for token in tokens:
            if isinstance(token, TokenList):
                query += "(" + self.recompose_query(token.tokens) + ")"
            else:
                query += str(token)
        
        return query

    def tree_to_dict_(self, tokens=None) -> Dict:
        if tokens is None:
            tokens = self.tree

        def process_token(token) -> Union[Dict, str]:
            if isinstance(token, TokenList):
                return {token.ttype: [process_token(t) for t in token.tokens]}
            else:
                return token.normalized

        return {None: [process_token(token) for token in tokens]}
    
    def tree_to_dict(self, tokens=None) -> Dict:
        if tokens is None:
            tokens = self.tree

        def process_token(token) -> Union[Dict, str]:
            if isinstance(token, TokenList):
                children = [process_token(t) for t in token.tokens]
                return {
                    "type": "TokenList",
                    "value": token.normalized,
                    "children": children
                }
            else:
                return {
                    "type": token.ttype,
                    "value": token.normalized
                }

        return {"type": "ROOT", "children": [process_token(token) for token in tokens]}

    def is_read_only(self, query: str) -> bool:
        try:
            if not query or not isinstance(query, str):
                raise ValueError("Invalid query provided")
                
            parsed = self.create_tree(query)
            if not parsed:
                raise ValueError("Failed to parse the query")

            modifying_keywords = ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER"]

            flat_tokens = self.flatten_tree(parsed)
            for keyword in modifying_keywords:
                if keyword in flat_tokens:
                    return False
        
            return True
        
        except ValueError as ve:
            raise ve
        except Exception as e:
            raise ValueError(f"SQL parse error: {e}")

# 




In [58]:
parser = SqlQuery()


In [59]:
tree  = parser.create_tree("SELECT * FROM table") 

In [61]:
parser.tree_to_dict()


{'type': 'ROOT',
 'children': [{'type': Token.Keyword.DML, 'value': 'SELECT'},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': Token.Wildcard, 'value': '*'},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': Token.Keyword, 'value': 'FROM'},
  {'type': Token.Text.Whitespace, 'value': ' '},
  {'type': Token.Keyword, 'value': 'TABLE'}]}

In [62]:
sqlparse.parse("SELECT * FROM table")[0].tokens

[<DML 'SELECT' at 0x7FC0DB3A53A0>,
 <Whitespace ' ' at 0x7FC0DB3A5700>,
 <Wildcard '*' at 0x7FC0DB3A5220>,
 <Whitespace ' ' at 0x7FC0DB38BB80>,
 <Keyword 'FROM' at 0x7FC0DB38B820>,
 <Whitespace ' ' at 0x7FC0DB38BC40>,
 <Keyword 'table' at 0x7FC0DB38BBE0>]

In [43]:
token = Token(DML, 'SELECT')

In [45]:
token.ttype

Token.Keyword.DML

In [54]:
tree[2].normalized


'column, *'

In [None]:
def test_create_tree():
    parser = SqlQuery()
    assert parser.create_tree("SELECT * FROM table") == [
        Token(DML, 'SELECT'),
        Token(DDL, '*'),
        Token(DDL, 'FROM'),
        Token(DDL, 'table')
    ]

def test_flatten_tree():
    parser = SqlQuery()
    parsed = parser.create_tree("SELECT * FROM table")
    assert parser.flatten_tree(parsed) == ['SELECT', '*', 'FROM', 'table']

def test_recompose_query():
    parser = SqlQuery()
    parsed = parser.create_tree("SELECT * FROM table")
    assert parser.recompose_query(parsed) == "SELECT * FROM table"

def test_is_read_only():
    parser = SqlQuery()
    assert parser.is_read_only("SELECT * FROM table") == True
    assert parser.is_read_only("UPDATE table SET column1 = 10") == False
    assert parser.is_read_only("DELETE FROM table WHERE condition") == False
    assert parser.is_read_only("DROP TABLE table_name") == False
    assert parser.is_read_only("ALTER TABLE table_name ADD column_name") == False

def test_complex_queries():
    parser = SqlQuery()
    query_nested = """
    SELECT * FROM (
        SELECT col1 FROM table1
        UNION
        SELECT col2 FROM table2
    ) AS subquery
    """
    query_with_with = """
    WITH cte AS (
        SELECT col1 FROM table1
        UNION
        SELECT col2 FROM table2
    )
    SELECT * FROM cte
    """

    assert parser.create_tree(query_nested) == [
        Token(DML, 'SELECT'),
        Token(DDL, '*'),
        Token(DDL, 'FROM'),
        Token(TokenList, [
            Token(DML, 'SELECT'),
            Token(DDL, 'col1'),
            Token(DDL, 'FROM'),
            Token(DDL, 'table1'),
            Token(DDL, 'UNION'),
            Token(DML, 'SELECT'),
            Token(DDL, 'col2'),
            Token(DDL, 'FROM'),
            Token(DDL, 'table2')
        ]),
        Token(DDL, 'AS'),
        Token(DDL, 'subquery')
    ]

    assert parser.create_tree(query_with_with) == [
        Token(TokenList, [
            Token(DDL, 'WITH'),
            Token(DDL, 'cte'),
            Token(TokenList, [
                Token(DML, 'SELECT'),
                Token(DDL, 'col1'),
                Token(DDL, 'FROM'),
                Token(DDL, 'table1'), Token(DDL, 'UNION'),
                Token(DML, 'SELECT'),
                Token(DDL, 'col2'),
                Token(DDL, 'FROM'),
                Token(DDL, 'table2')
            ])
        ]),
        Token(DML, 'SELECT'),
        Token(DDL, '*'),
        Token(DDL, 'FROM'),
        Token(DDL, 'cte')
    ]
    
    


# Run tests
test_create_tree()
test_flatten_tree()
test_recompose_query()
test_is_read_only()
test_complex_queries()

In [63]:
import sqlparse
from sqlparse.sql import Token, TokenList
from typing import List, Dict, Union
from functools import lru_cache

class SqlQuery:
    def __init__(self):
        self.tree: List[Token] = []
        self.parsed_query = None

    @lru_cache(maxsize=None)
    def create_tree(self, query: str) -> List[Token]:
        """
        Parses the SQL query and creates a parse tree.
        """
        if not query or not isinstance(query, str):
            raise ValueError("Invalid query provided")

        parsed = sqlparse.parse(query)
        if not parsed:
            raise ValueError("Failed to parse the query")

        self.parsed_query = parsed[0]
        self.tree = self.parsed_query.tokens
        return self.tree

    def flatten_tree(self, tokens=None) -> List[str]:
        """
        Flattens the parse tree to a list of token strings.
        """
        if tokens is None:
            tokens = self.tree

        flat_tokens = []
        for token in tokens:
            if isinstance(token, TokenList):
                flat_tokens.extend(self.flatten_tree(token.tokens))
            else:
                flat_tokens.append(token.normalized)
        return flat_tokens

    def recompose_query(self, tokens=None) -> str:
        """
        Recomposes the query from the token list.
        """
        if tokens is None:
            tokens = self.tree

        query = ''
        for token in tokens:
            if isinstance(token, TokenList):
                query += "(" + self.recompose_query(token.tokens) + ")"
            else:
                query += str(token)
        return query

    def is_read_only(self, query: str) -> bool:
        """
        Checks if the query is read-only (SELECT).
        """
        try:
            parsed = self.create_tree(query)
            if not parsed:
                raise ValueError("Failed to parse the query")

            modifying_keywords = ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER"]
            flat_tokens = self.flatten_tree(parsed)
            return not any(keyword in flat_tokens for keyword in modifying_keywords)
        except ValueError as ve:
            raise ve
        except Exception as e:
            raise ValueError(f"SQL parse error: {e}")

    def tree_to_dict(self, tokens=None) -> Dict:
        """
        Converts the parse tree to a nested dictionary.
        """
        if tokens is None:
            tokens = self.tree

        def process_token(token) -> Union[Dict, str]:
            if isinstance(token, TokenList):
                children = [process_token(t) for t in token.tokens]
                return {
                    "type": "TokenList",
                    "value": token.normalized,
                    "children": children
                }
            else:
                return {
                    "type": token.ttype,
                    "value": token.normalized
                }

        return {"type": "ROOT", "children": [process_token(token) for token in tokens]}

    def standardize_query(self, query: str) -> str:
        """
        Standardizes the SQL query by formatting.
        """
        return sqlparse.format(query, reindent=True, keyword_case='upper', strip_whitespace=True)




In [64]:
sql_query_obj = SqlQuery()
raw_query = "SELECT *    FROM    my_table WHERE   id = 1"
standardized_query = sql_query_obj.standardize_query(raw_query)
print("Standardized Query:", standardized_query)

Standardized Query: SELECT *
FROM my_table
WHERE id = 1


In [66]:
def test_select_query_tree_to_dict():
    query = "SELECT * FROM users"
    sql_query = SqlQuery()
    sql_query.create_tree(query)

    expected_tree_dict = {
        "type": "ROOT",
        "children": [
            {"type": sqlparse.tokens.DML, "value": "SELECT"},
            {"type": sqlparse.tokens.Wildcard, "value": "*"},
            {"type": sqlparse.tokens.Keyword, "value": "FROM"},
            {"type": None, "value": "users"}
        ]
    }

    tree_dict = sql_query.tree_to_dict()
    print(tree_dict)
    print(expected_tree_dict)
    # assert tree_dict == expected_tree_dict, "tree_to_dict failed for SELECT query"


In [67]:
test_select_query_tree_to_dict()

{'type': 'ROOT', 'children': [{'type': Token.Keyword.DML, 'value': 'SELECT'}, {'type': Token.Text.Whitespace, 'value': ' '}, {'type': Token.Wildcard, 'value': '*'}, {'type': Token.Text.Whitespace, 'value': ' '}, {'type': Token.Keyword, 'value': 'FROM'}, {'type': Token.Text.Whitespace, 'value': ' '}, {'type': 'TokenList', 'value': 'users', 'children': [{'type': Token.Name, 'value': 'users'}]}]}
{'type': 'ROOT', 'children': [{'type': Token.Keyword.DML, 'value': 'SELECT'}, {'type': Token.Wildcard, 'value': '*'}, {'type': Token.Keyword, 'value': 'FROM'}, {'type': None, 'value': 'users'}]}
