In [8]:
from moz_sql_parser import parse

# Predicate to turn into a graph.
# NOTE(review): in SQL, double quotes delimit identifiers, not strings, so "%s"
# is parsed as a column-like name (the output below shows a bare '%s', shaped
# exactly like the column names). A single-quoted '%s' would presumably come
# back as {'literal': '%s'} instead — make sure the graph-building cell handles
# that form before switching quotes.
predicate = '(nation.n_regionkey < 2 or nation.n_name like "%s") and region.r_regionkey <> 1'

# parse() is given a complete statement, so wrap the bare predicate in a
# throwaway SELECT whose only interesting part is its WHERE clause.
sql_query = "SELECT count(*) FROM dummy_table WHERE " + predicate
parsed = parse(sql_query)

# A statement without a WHERE clause yields {} rather than raising KeyError.
where_clause = parsed['where'] if 'where' in parsed else {}


In [9]:
# Inspect the nested dict that parse() produced for the predicate's WHERE clause.
where_clause

{'and': [{'or': [{'lt': ['nation.n_regionkey', 2]},
    {'like': ['nation.n_name', '%s']}]},
  {'neq': ['region.r_regionkey', 1]}]}

In [6]:
import torch
from torch_geometric.data import HeteroData

# Global tables filled in by traverse() / add_leaf_node(). Every node entry is a
# dict carrying its 'id'; edges holds (parent_id, child_id) tuples.
# NOTE(review): torch and HeteroData are imported in this cell but are not used
# by the code shown here.
operation_nodes, literal_nodes, numeral_nodes, column_nodes = [], [], [], []
edges = []

# Next node id to hand out. One counter is shared by every node type, so ids
# are unique across all of the tables above.
node_id = 0

# Python id() of a visited object -> the node id assigned to it.
node_map = {}

# moz_sql_parser comparison key -> operator label stored on the operation node.
# Insertion order is the lookup priority (same order as the original elif chain).
_COMPARISON_OPS = {
    'eq': '=',
    'neq': '<>',
    'gt': '>',
    'gte': '>=',
    'lt': '<',
    'lte': '<=',
    'like': 'LIKE',
}


def _traverse_child(child, parent_id):
    """Add one condition/operand under ``parent_id``, linked by exactly one edge.

    ``traverse`` never creates the edge to its own parent, so for dict children
    that edge is appended here. Non-dict values go to ``add_leaf_node``, which
    already appends the ``(parent_id, leaf_id)`` edge itself.

    Returns:
        The node id assigned to ``child``.
    """
    if isinstance(child, dict):
        child_id = traverse(child, parent_id)
        edges.append((parent_id, child_id))
        return child_id
    return add_leaf_node(child, parent_id)


def traverse(parsed_dict, parent_id=None):
    """Recursively add a parsed SQL expression tree to the global graph tables.

    Boolean connectives (``and`` / ``or`` / ``not``) and the comparisons listed
    in ``_COMPARISON_OPS`` become entries in ``operation_nodes``.
    ``{'literal': v}`` dicts become entries in ``literal_nodes``. Non-dict
    operands are delegated to ``add_leaf_node``. Any other operator key is
    recorded as an ``'UNKNOWN'`` operation whose operands are still traversed.
    Each parent/child pair is appended to ``edges`` exactly once, as
    ``(parent_id, child_id)``.

    Args:
        parsed_dict: An expression in the dict form returned by
            ``moz_sql_parser.parse`` (e.g. its ``'where'`` entry). A non-dict
            value is treated as a single leaf.
        parent_id: Not used; kept for backward compatibility. The caller is
            responsible for linking the returned id to its parent.

    Returns:
        The node id assigned to ``parsed_dict`` (the root of this subtree).

    Raises:
        ValueError: If ``parsed_dict`` is an empty dict (e.g. the query had no
            WHERE clause), or a comparison does not have exactly two operands.
    """
    global node_id

    if not isinstance(parsed_dict, dict):
        # e.g. ``WHERE flag``: the whole expression is a single operand. Pass no
        # parent so add_leaf_node does not create an edge that the caller owns.
        return add_leaf_node(parsed_dict, None)
    if not parsed_dict:
        raise ValueError("traverse() got an empty expression; does the query have a WHERE clause?")

    current_id = node_id
    node_map[id(parsed_dict)] = current_id
    node_id += 1

    if 'literal' in parsed_dict:
        # NOTE(review): moz_sql_parser presumably wraps quoted string constants as
        # {'literal': value}; they are leaves, not operators.
        literal_nodes.append({'id': current_id, 'value': parsed_dict['literal']})
        return current_id

    if 'and' in parsed_dict or 'or' in parsed_dict:
        key = 'and' if 'and' in parsed_dict else 'or'
        operation_nodes.append({'id': current_id, 'op': key.upper()})
        for condition in parsed_dict[key]:
            _traverse_child(condition, current_id)
    elif 'not' in parsed_dict:
        operation_nodes.append({'id': current_id, 'op': 'NOT'})
        _traverse_child(parsed_dict['not'], current_id)
    else:
        key = next((k for k in _COMPARISON_OPS if k in parsed_dict), None)
        if key is not None:
            op = _COMPARISON_OPS[key]
            # Unpacking enforces the binary shape (ValueError otherwise).
            left, right = parsed_dict[key]
            operands = [left, right]
        else:
            # Operator key not handled above: keep the node as 'UNKNOWN' but
            # still attach its operands instead of failing on unbound names.
            op = 'UNKNOWN'
            operands = next(iter(parsed_dict.values()))
            if not isinstance(operands, list):
                operands = [operands]
        operation_nodes.append({'id': current_id, 'op': op})
        for operand in operands:
            _traverse_child(operand, current_id)

    return current_id

def add_leaf_node(value, parent_id):
    """Register an operand value as a graph node and link it to ``parent_id``.

    Classification into the global node tables:

    * ``int`` / ``float`` (``bool`` included, since it subclasses ``int``) ->
      ``numeral_nodes``, with the value stored as ``float``;
    * ``{'literal': v}`` -> ``literal_nodes``, with ``v`` stored unchanged;
    * a ``str`` containing ``%`` or starting with ``"`` -> ``literal_nodes``,
      with surrounding ``"`` stripped;
    * any other ``str`` -> ``column_nodes``;
    * any other dict -> built by ``traverse`` as a nested expression.

    Args:
        value: One operand taken from a parsed SQL expression.
        parent_id: Id of the operation node owning this operand, or ``None`` to
            create the node without an incoming edge.

    Returns:
        The node id that represents ``value``.
    """
    global node_id

    if isinstance(value, dict) and 'literal' not in value:
        # Nested expression: traverse() allocates its own id, so link to that id
        # directly instead of through an extra id no node table would contain.
        child_id = traverse(value, parent_id)
        if parent_id is not None:
            edges.append((parent_id, child_id))
        return child_id

    current_id = node_id
    # NOTE(review): id() can be shared by equal objects (e.g. CPython's cached
    # small ints), so a later leaf may overwrite an earlier entry here.
    node_map[id(value)] = current_id
    node_id += 1

    if isinstance(value, dict):
        literal_nodes.append({'id': current_id, 'value': value['literal']})
    elif isinstance(value, (int, float)):
        numeral_nodes.append({'id': current_id, 'value': float(value)})
    elif isinstance(value, str):
        # '_' is deliberately not treated as a literal marker: it appears in
        # ordinary column names such as 'nation.n_regionkey'.
        if '%' in value or value.startswith('"'):
            literal_nodes.append({'id': current_id, 'value': value.strip('"')})
        else:
            column_nodes.append({'id': current_id, 'name': value})
    else:
        # Keep going (best effort), but make the gap visible instead of silent.
        import warnings
        warnings.warn(
            f"add_leaf_node: unsupported operand {value!r}; node {current_id} "
            "is not recorded in any node list"
        )

    if parent_id is not None:
        edges.append((parent_id, current_id))

    return current_id

# Build the graph from the parsed WHERE clause; the root node id it returns is
# shown as this cell's output.
# NOTE(review): this cell reads `where_clause` from the parsing cell, yet its
# execution count (In [6]) is lower than that cell's (In [8]) — verify with
# Restart Kernel & Run All.
root_id = traverse(where_clause)
root_id


0