In [None]:
from collections import defaultdict
from typing import Any, Dict, Union, Iterable

import sqlglot
from sqlglot import exp


# Accepted SQL inputs: raw SQL text, an already-parsed sqlglot Expression, or
# a collection of Expressions (e.g. from sqlglot.parse) of which only the
# first one is used.
SqlInput = Union[str, exp.Expression, Iterable[exp.Expression]]


def _ensure_expression(sql_or_expr: SqlInput) -> exp.Expression:
    """
    Normalize input to a single sqlglot Expression.

    - If it's already an Expression, return it.
    - If it's a string, parse it with ``sqlglot.parse_one``.
    - If it's any other iterable (e.g. the list returned by ``sqlglot.parse``,
      a tuple, or a generator), return its first item.

    Raises:
        ValueError: the iterable is empty.
        TypeError: the input is not a string, Expression, or iterable, or the
            iterable's first item is not an Expression.
    """
    if isinstance(sql_or_expr, exp.Expression):
        return sql_or_expr

    if isinstance(sql_or_expr, str):
        return sqlglot.parse_one(sql_or_expr)

    # Accept any iterable, not only list/tuple, so behavior matches the
    # ``Iterable[exp.Expression]`` member of ``SqlInput``.
    try:
        items = iter(sql_or_expr)
    except TypeError:
        raise TypeError(f"Unsupported type for SQL input: {type(sql_or_expr)}") from None

    # Use a private sentinel (not None) to detect exhaustion, so a None first
    # item is reported as a type error rather than as an empty input.
    missing = object()
    first = next(items, missing)
    if first is missing:
        raise ValueError("Empty iterable of expressions passed to _ensure_expression")
    if not isinstance(first, exp.Expression):
        raise TypeError(f"Expected Expression in iterable, got {type(first)}")
    return first


def _max_subquery_depth(node: exp.Expression, depth: int = 0) -> int:
    """
    Compute max nesting depth of Subquery nodes below ``node``.

    Every descendant's depth is ``depth`` plus the number of ``exp.Subquery``
    nodes on the path from ``node`` down to it (``node`` itself is not
    counted). The result is the maximum over all nodes, or ``depth`` when
    there are no subqueries.

    The traversal uses an explicit stack instead of recursion. Long AND/OR
    chains produce trees as deep as the number of terms, which would
    otherwise exceed Python's recursion limit.

    In some sqlglot versions, iter_expressions() yields (arg_name, expr)
    tuples; in others it yields just expr. We normalize that here.
    """
    max_depth = depth
    pending = [(node, depth)]

    while pending:
        current, current_depth = pending.pop()
        if current_depth > max_depth:
            max_depth = current_depth

        for child in current.iter_expressions():
            # Handle both `child` and `(arg_name, child)` forms
            if isinstance(child, tuple):
                _, child_expr = child
            else:
                child_expr = child

            if not isinstance(child_expr, exp.Expression):
                continue

            # Entering a Subquery adds one nesting level for it and its subtree.
            child_depth = current_depth + 1 if isinstance(child_expr, exp.Subquery) else current_depth
            pending.append((child_expr, child_depth))

    return max_depth


def sql_complexity(sql_or_expr: SqlInput) -> Dict[str, Any]:
    """
    Compute a cyclomatic-style complexity score for SQL using sqlglot.

    Complexity starts at 1 and increases by one for each:
      - join
      - subquery
      - CTE
      - boolean operator (AND/OR)
      - CASE expression
      - set operation (UNION/INTERSECT/EXCEPT)
    plus one for every level of subquery nesting beyond the first.

    Returns:
        ``{"complexity": int, "breakdown": dict}``. The breakdown always
        holds the keys joins, set_ops, ctes, cases, bool_ops, subqueries and
        subquery_max_depth, in that order.
    """
    tree = _ensure_expression(sql_or_expr)

    # Breakdown key -> node types that each add one to it.
    categories = (
        ("joins", (exp.Join,)),
        ("set_ops", (exp.Union, exp.Intersect, exp.Except)),
        ("ctes", (exp.CTE,)),
        ("cases", (exp.Case,)),
        ("bool_ops", (exp.And, exp.Or)),
        ("subqueries", (exp.Subquery,)),
    )
    counts: Dict[str, int] = {key: 0 for key, _ in categories}

    # find_all() yields bare nodes on every sqlglot release, whereas walk()
    # yielded (node, parent, key) tuples on older ones. With walk(), every
    # isinstance check failed there and all counts silently stayed 0.
    tracked_types = tuple(t for _, types in categories for t in types)
    for node in tree.find_all(*tracked_types):
        for key, types in categories:
            if isinstance(node, types):
                counts[key] += 1

    max_subquery_depth = _max_subquery_depth(tree)
    counts["subquery_max_depth"] = max_subquery_depth

    # --- Compute cyclomatic-style score ---
    score = 1 + sum(counts[key] for key, _ in categories)
    # A single level of subquery is already covered by "subqueries"; only
    # deeper nesting adds extra weight.
    score += max(0, max_subquery_depth - 1)

    return {
        "complexity": score,
        "breakdown": dict(counts),
    }


if __name__ == "__main__":
    # Sample queries of increasing complexity: a plain scan, a join with
    # boolean filters, and a CTE pipeline with a scalar subquery.
    examples = [
        "SELECT * FROM orders",
        """
        SELECT c.customer_id, c.name, SUM(o.total) AS total_spent
        FROM customers c
        JOIN orders o ON o.customer_id = c.customer_id
        WHERE o.order_date >= '2023-01-01'
          AND (o.status = 'PAID' OR o.status = 'SHIPPED')
        GROUP BY c.customer_id, c.name
        """,
        """
        WITH high_value_orders AS (
            SELECT o.order_id, o.customer_id, o.total
            FROM orders o
            WHERE o.total > 100
        ),
        customer_totals AS (
            SELECT c.customer_id, SUM(h.total) AS total_spent
            FROM customers c
            JOIN high_value_orders h ON h.customer_id = c.customer_id
            GROUP BY c.customer_id
        )
        SELECT *
        FROM customer_totals ct
        WHERE ct.total_spent > (
            SELECT AVG(total_spent) FROM customer_totals
        )
        """,
    ]
    separator_line = "-" * 40

    for query_number, query_text in enumerate(examples, start=1):
        report = sql_complexity(query_text)
        total_score, parts = report["complexity"], report["breakdown"]
        print(f"Query {query_number} complexity: {total_score}")
        print("  breakdown:", parts)
        print(separator_line)


In [None]:
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Union, Iterable

import sqlglot
from sqlglot import exp


# Accepted SQL inputs: raw SQL text, an already-parsed sqlglot Expression, or
# a collection of Expressions (e.g. from sqlglot.parse) of which only the
# first one is used.
SqlInput = Union[str, exp.Expression, Iterable[exp.Expression]]


@dataclass
class ExpressionComplexity:
    """Complexity result for one query unit: a CTE body or the main query.

    Constructor arguments are, in order: ``name``, ``score``, ``metrics``.
    """

    # CTE name, "<anonymous_cte>", or "<main>" for the outer query
    name: str
    # Weighted score computed from ``metrics``
    score: int
    # Raw per-metric counters, keyed by metric name
    metrics: Dict[str, int]


def _maybe_parse(sql_or_expr: SqlInput) -> exp.Expression:
    """
    Normalize input to a single sqlglot Expression.

    - If it's already an Expression, return it.
    - If it's a string, parse it with ``sqlglot.parse_one``.
    - If it's any other iterable (e.g. the list returned by ``sqlglot.parse``,
      a tuple, or a generator), return its first item.

    Raises:
        ValueError: the iterable is empty.
        TypeError: the input is not a string, Expression, or iterable, or the
            iterable's first item is not an Expression.
    """
    if isinstance(sql_or_expr, exp.Expression):
        return sql_or_expr

    if isinstance(sql_or_expr, str):
        return sqlglot.parse_one(sql_or_expr)

    # Accept any iterable, not only list/tuple, so behavior matches the
    # ``Iterable[exp.Expression]`` member of ``SqlInput``.
    try:
        items = iter(sql_or_expr)
    except TypeError:
        raise TypeError(f"Unsupported SQL input type: {type(sql_or_expr)}") from None

    # Use a private sentinel (not None) to detect exhaustion, so a None first
    # item is reported as a type error rather than as an empty input.
    missing = object()
    first = next(items, missing)
    if first is missing:
        raise ValueError("Empty iterable of expressions")
    if not isinstance(first, exp.Expression):
        raise TypeError(f"Expected Expression in iterable, got {type(first)}")
    return first


def _extract_query_expressions(root: exp.Expression) -> List[tuple[str, exp.Expression]]:
    """
    Treat CTEs like separate 'functions':
    - each CTE: (cte_name, cte_body_expression)
    - final query: ("<main>", main_expression)

    sqlglot stores a WITH clause as an argument of the query that owns it
    (e.g. parsing "WITH a AS (...) SELECT ..." yields an ``exp.Select`` that
    holds an ``exp.With``). The root therefore is not an ``exp.With``, and
    checking only for that meant CTEs were never split out. Here the top-level
    ``exp.With`` is found among ``root``'s arguments. The main query is a
    *copy* of ``root`` with that clause removed, so CTE bodies are not counted
    again under "<main>" and the caller's tree is left untouched. A bare
    ``exp.With`` root is still handled as before.
    """
    if isinstance(root, exp.With):
        # Legacy shape: the WITH node itself was passed in.
        with_node = root
        main = root.this
    else:
        with_key = next(
            (key for key, value in root.args.items() if isinstance(value, exp.With)),
            None,
        )
        if with_key is None:
            # no WITH: just a main query
            return [("<main>", root)]
        with_node = root.args[with_key]
        main = root.copy()
        main.set(with_key, None)  # detach the CTE definitions from the main body

    results: List[tuple[str, exp.Expression]] = []

    # CTEs (list of exp.CTE)
    for cte in with_node.args.get("expressions") or []:
        if not isinstance(cte, exp.CTE):
            continue
        name = cte.alias_or_name or "<anonymous_cte>"
        body = cte.this  # the SELECT / set operation inside the CTE
        if isinstance(body, exp.Expression):
            results.append((name, body))

    # main query body after WITH
    if isinstance(main, exp.Expression):
        results.append(("<main>", main))

    return results


def _measure_expression(expr: exp.Expression) -> Dict[str, int]:
    """
    Compute raw metrics for a single expression (CTE body or main query).

    Counts cover ``expr`` and everything nested inside it.

    Returns:
        Dict with the keys num_joins, num_tables, num_agg_funcs,
        num_window_funcs, num_group_keys, num_subqueries, num_set_ops and
        num_select_cols (in that order).
    """

    def count(*node_types: type) -> int:
        # Number of nodes under ``expr`` that are instances of any of ``node_types``.
        return sum(1 for _ in expr.find_all(*node_types))

    return {
        "num_joins": count(exp.Join),
        "num_tables": count(exp.Table),
        # AggFunc covers SUM/COUNT/etc; Window is OVER(...)
        "num_agg_funcs": count(exp.AggFunc),
        "num_window_funcs": count(exp.Window),
        # Total number of keys across every GROUP BY clause
        "num_group_keys": sum(len(group.expressions) for group in expr.find_all(exp.Group)),
        "num_subqueries": count(exp.Subquery),
        # One multi-type search visits each set-operation node once. Older
        # sqlglot releases define Intersect/Except as subclasses of Union, so
        # separate Union/Intersect/Except searches counted those nodes twice.
        "num_set_ops": count(exp.Union, exp.Intersect, exp.Except),
        # Number of select columns (not used in score by default, but nice to log)
        "num_select_cols": len(expr.selects) if isinstance(expr, exp.Select) else 0,
    }


def _score_expression(m: Dict[str, int]) -> int:
    """
    Turn raw metrics into a cyclomatic-style complexity score.
    Feel free to tweak weights.
    """
    score = 1
    score += m["num_joins"]
    score += m["num_agg_funcs"]
    score += 2 * m["num_window_funcs"]
    score += max(0, m["num_group_keys"] - 1)       # group by more than 1 key
    score += 2 * m["num_subqueries"]
    score += 2 * m["num_set_ops"]

    # joins + aggregations / windows / group by = "granularity awareness" bump
    if m["num_joins"] and (m["num_agg_funcs"] or m["num_window_funcs"] or m["num_group_keys"]):
        score += 1

    return score


def sql_complexity_by_expression(sql_or_expr: SqlInput) -> List[ExpressionComplexity]:
    """
    Score each query unit of a SQL statement separately.

    The input is parsed when needed and split into its CTE bodies plus the
    main query. Each unit is measured and scored independently.

    Returns:
        One ExpressionComplexity per CTE, followed by one for the main query,
        in the order produced by _extract_query_expressions.
    """
    parsed = _maybe_parse(sql_or_expr)
    units = _extract_query_expressions(parsed)

    def assess(label: str, unit: exp.Expression) -> ExpressionComplexity:
        # Measure once, then score from the same metrics dict that is stored.
        unit_metrics = _measure_expression(unit)
        return ExpressionComplexity(
            name=label,
            score=_score_expression(unit_metrics),
            metrics=unit_metrics,
        )

    return [assess(label, unit) for label, unit in units]


# Example usage
if __name__ == "__main__":
    sql = """
    WITH high_value_orders AS (
        SELECT o.order_id, o.customer_id, o.total
        FROM orders o
        WHERE o.total > 100
    ),
    customer_totals AS (
        SELECT c.customer_id,
               SUM(h.total) AS total_spent,
               AVG(h.total) AS avg_spent
        FROM customers c
        JOIN high_value_orders h
          ON h.customer_id = c.customer_id
        GROUP BY c.customer_id
    )
    SELECT customer_id,
           total_spent,
           avg_spent,
           RANK() OVER (ORDER BY total_spent DESC) AS spending_rank
    FROM customer_totals
    WHERE total_spent > (
        SELECT AVG(total_spent) FROM customer_totals
    )
    """

    # Scores above this mirror McCabe's classic "too complex" threshold.
    high_complexity_threshold = 10

    for ec in sql_complexity_by_expression(sql):
        print(f"Expression {ec.name!r} complexity score: {ec.score}")
        print("  metrics:", ec.metrics)
        if ec.score > high_complexity_threshold:
            # The marker is written as a named escape (U+1F536) so a wrong
            # source-file encoding cannot turn it into mojibake again.
            print("  \N{LARGE ORANGE DIAMOND} High complexity (like McCabe > 10)")
        print("-" * 60)
