In [17]:
!pip install sqlparse



In [16]:
import sqlparse
from sqlparse.sql import IdentifierList, Identifier, TokenList, Where, Parenthesis
from sqlparse.tokens import Keyword, DML, DDL

class SQLNode:
    """A tree node wrapping one sqlparse token.

    ``token`` may be ``None``: ``SQLTree`` creates its root that way. Every
    accessor therefore treats a missing token as "no name / not a keyword /
    empty repr" instead of raising ``AttributeError``.
    """

    def __init__(self, token):
        """Store ``token`` and start with no children and no parent."""
        self.token = token
        self.children = []
        self.parent = None

    def add_child(self, node):
        """Append ``node`` to this node's children and point it back here."""
        node.parent = self
        self.children.append(node)

    def get_name(self):
        """Return a display name for the wrapped token.

        Returns ``None`` when there is no token, the result of
        ``get_real_name()`` for an ``Identifier``, and the raw ``value`` of
        the token otherwise.
        """
        # The token-less root has nothing to name; check this first so the
        # attribute accesses below are never attempted on None.
        if self.token is None:
            return None
        if isinstance(self.token, Identifier):
            return self.token.get_real_name()
        return self.token.value

    def is_keyword(self):
        """Return True if the token's ``ttype`` is within ``Keyword``."""
        if self.token is None:
            return False
        return self.token.ttype in Keyword

    def __repr__(self):
        """Return ``''`` for a token-less node, else the token value and ttype."""
        # Explicit None check: the empty repr is reserved for the token-less
        # node and does not depend on how a token object evaluates as a bool.
        if self.token is None:
            return ''
        return f"SQLNode(token={self.token.value}, ttype={self.token.ttype})"

class SQLTree:
    """A tree of ``SQLNode`` objects with a cursor marking the insertion point.

    The root is a ``SQLNode`` whose token is ``None``; ``current_node`` starts
    at the root and is moved by ``enter_node`` / ``leave_node``.
    """

    def __init__(self):
        """Create the token-less root and place the cursor on it."""
        top = SQLNode(None)
        self.root = top
        self.current_node = top

    def add_node(self, node):
        """Attach ``node`` under the cursor without moving the cursor."""
        self.current_node.add_child(node)

    def enter_node(self, node):
        """Attach ``node`` under the cursor, then move the cursor onto it."""
        self.add_node(node)
        self.current_node = node

    def leave_node(self):
        """Move the cursor to its parent; do nothing if it has none."""
        enclosing = self.current_node.parent
        if enclosing:
            self.current_node = enclosing

    def traverse(self, node=None, level=0):
        """Print the subtree depth-first, two spaces of indent per level.

        When ``node`` is ``None`` or carries no token, printing starts from
        the root instead.
        """
        start = node
        if start is None or start.token is None:
            start = self.root
        indent = "  " * level
        print(indent + str(start))
        deeper = level + 1
        for descendant in start.children:
            self.traverse(descendant, deeper)

def parse_tokens(tokens, parent):
    """Recursively mirror a sequence of sqlparse tokens as ``SQLNode`` children.

    Args:
        tokens: iterable of sqlparse tokens (a statement's or group's
            ``.tokens``).
        parent: the ``SQLNode`` that receives the nodes built from ``tokens``.

    Placement rules, applied per non-whitespace token:
      * a group token becomes a child of ``parent`` and its own ``.tokens``
        are parsed beneath it;
      * a keyword becomes a child of ``parent`` and is remembered;
      * any other leaf token becomes a child of the remembered keyword, which
        is then forgotten, or a child of ``parent`` when none is remembered.

    The previous version also had dedicated branches for ``IdentifierList``,
    ``Identifier``, ``Where`` and ``Parenthesis`` placed *after* the
    ``is_group`` test. Those objects are group tokens, so the branches could
    never run; they are removed here and the resulting tree is unchanged.
    """
    previous_keyword = None

    for token in tokens:
        if token.is_whitespace:
            continue

        node = SQLNode(token)

        if token.is_group:
            # Groups always hang off `parent` (never off a pending keyword)
            # and do not consume the pending keyword.
            parent.add_child(node)
            parse_tokens(token.tokens, node)
        elif token.ttype in Keyword:
            parent.add_child(node)
            previous_keyword = node
        elif previous_keyword is not None:
            # First plain leaf after a keyword is attached beneath it; the
            # keyword then stops collecting further leaves.
            previous_keyword.add_child(node)
            previous_keyword = None
        else:
            parent.add_child(node)

def parse_sql_query(query, statement_idx=0):
    """Parse ``query`` with sqlparse and build a tree for one of its statements.

    Args:
        query: SQL text, possibly containing several statements.
        statement_idx: index into the sequence returned by
            ``sqlparse.parse``; an out-of-range index raises ``IndexError``.

    Returns:
        An ``SQLTree`` for the selected statement, or ``None`` when
        ``sqlparse.parse`` yields nothing.
    """
    statements = sqlparse.parse(query)
    if statements:
        chosen_tokens = statements[statement_idx].tokens
        result = SQLTree()
        parse_tokens(chosen_tokens, result.root)
        return result
    return None

# Example usage with a complex SQL query.
# The script holds several semicolon-terminated statements: four CREATE TABLE
# statements, a WITH ... INSERT statement, and a final SELECT with joins.
query = """
-- DDL: Create tables
CREATE TABLE employees (
    employee_id INT PRIMARY KEY,
    first_name VARCHAR(50),
    last_name VARCHAR(50),
    department_id INT,
    salary DECIMAL(10, 2)
);

CREATE TABLE departments (
    department_id INT PRIMARY KEY,
    department_name VARCHAR(50)
);

CREATE TABLE projects (
    project_id INT PRIMARY KEY,
    project_name VARCHAR(100),
    start_date DATE,
    end_date DATE
);

CREATE TABLE employee_projects (
    employee_id INT,
    project_id INT,
    assignment_date DATE,
    PRIMARY KEY (employee_id, project_id)
);

-- CTE: Common Table Expressions
WITH department_salaries AS (
    SELECT 
        d.department_name,
        SUM(e.salary) AS total_salary
    FROM 
        employees e
    JOIN 
        departments d ON e.department_id = d.department_id
    GROUP BY 
        d.department_name
), project_counts AS (
    SELECT 
        e.employee_id,
        COUNT(ep.project_id) AS project_count
    FROM 
        employees e
    LEFT JOIN 
        employee_projects ep ON e.employee_id = ep.employee_id
    GROUP BY 
        e.employee_id
)

-- DML: Insert and Select Statements
INSERT INTO employee_projects (employee_id, project_id, assignment_date)
VALUES (1, 1, '2023-01-01'),
       (2, 1, '2023-01-01'),
       (1, 2, '2023-02-01'),
       (3, 3, '2023-03-01');

SELECT 
    e.first_name,
    e.last_name,
    d.department_name,
    ps.total_salary,
    pc.project_count
FROM 
    employees e
JOIN 
    departments d ON e.department_id = d.department_id
LEFT JOIN 
    department_salaries ps ON d.department_name = ps.department_name
LEFT JOIN 
    project_counts pc ON e.employee_id = pc.employee_id
WHERE 
    e.salary > 50000
ORDER BY 
    e.last_name;

"""
# Build the tree for the statement at index 5 of the script; the printed
# output shows this is the trailing SELECT statement.
sql_tree = parse_sql_query(query, statement_idx=5)

# Output the parsed tree structure: one node per line, indented by depth
sql_tree.traverse()


  SQLNode(token=SELECT, ttype=Token.Keyword.DML)
  SQLNode(token=e.first_name,
    e.last_name,
    d.department_name,
    ps.total_salary,
    pc.project_count, ttype=None)
    SQLNode(token=e.first_name, ttype=None)
      SQLNode(token=e, ttype=Token.Name)
      SQLNode(token=., ttype=Token.Punctuation)
      SQLNode(token=first_name, ttype=Token.Name)
    SQLNode(token=,, ttype=Token.Punctuation)
    SQLNode(token=e.last_name, ttype=None)
      SQLNode(token=e, ttype=Token.Name)
      SQLNode(token=., ttype=Token.Punctuation)
      SQLNode(token=last_name, ttype=Token.Name)
    SQLNode(token=,, ttype=Token.Punctuation)
    SQLNode(token=d.department_name, ttype=None)
      SQLNode(token=d, ttype=Token.Name)
      SQLNode(token=., ttype=Token.Punctuation)
      SQLNode(token=department_name, ttype=Token.Name)
    SQLNode(token=,, ttype=Token.Punctuation)
    SQLNode(token=ps.total_salary, ttype=None)
      SQLNode(token=ps, ttype=Token.Name)
      SQLNode(token=., ttype=Token.Punctu