In [None]:
import sqlglot
from sqlglot import exp

class SSCSCalculator:
    def __init__(self):
        # Configuration for weights
        self.weights = {
            exp.Join: 1,
            exp.Where: 1,
            exp.Group: 1,
            exp.Having: 1,
            exp.Order: 1,
            exp.Case: 2,           # Branching logic = higher load
            exp.Window: 2,         # Window functions are complex
            exp.Connector: 1,      # AND / OR
            exp.Subquery: 1        # Base penalty for existence of subquery
        }
        
        # Configuration for Semantic Penalty
        self.semantic_weight = 0.5  # Alpha in the formula
        self.min_alias_length = 3
        self.generic_aliases = {'temp', 'data', 't', 'x', 'val', 'obj', 'row'}

    def calculate(self, sql: str):
        """
        Parses SQL and returns the SSCS score along with a detailed breakdown.
        """
        try:
            parsed = sqlglot.parse_one(sql)
        except Exception as e:
            return {"error": f"Parse Error: {e}"}

        # 1. Isolate CTEs and Main Query
        ctes = []
        main_query = parsed

        # If there is a WITH clause, extract CTEs
        if parsed.find(exp.CTE):
            # We treat CTEs as independent "functions" for complexity
            # Note: sqlglot stores CTEs in the 'with' arg of the main expression
            ctes = parsed.find_all(exp.CTE)
                # We analyze the main query as if the CTEs are just tables
                # (The complexity of defining the CTE is handled separately)


        
        
        # 2. Calculate Structural Complexity (C_struct)
        # Sum of CTE complexities + Main Query complexity
        struct_score = 0
        component_scores = []

        # Analyze CTEs (Depth starts at 0 for each, promoting modularity)
        for cte in ctes:
            cte_score = self._compute_structural_score(cte.this, depth=0)
            struct_score += cte_score
            component_scores.append(f"CTE '{cte.alias}': {cte_score}")

        # Analyze Main Query (Depth starts at 0)
        # We explicitly exclude the WITH clause from traversal to avoid double counting
        main_score = self._compute_structural_score(main_query, depth=0, exclude_node=exp.With)
        struct_score += main_score
        component_scores.append(f"Main Query: {main_score}")

        # 3. Calculate Semantic Penalty (P_sem)
        # We look at all aliases across the entire parsed tree globally
        semantic_penalty, alias_stats = self._compute_semantic_penalty(parsed)

        # 4. Final Formula: SSCS = C_struct * (1 + P_sem)
        final_sscs = struct_score * (1 + semantic_penalty)

        return {
            "sscs_score": round(final_sscs, 2),
            "structural_score": struct_score,
            "semantic_penalty": round(semantic_penalty, 2),
            "breakdown": component_scores,
            "alias_analysis": alias_stats
        }

    def _compute_structural_score(self, node, depth, exclude_node=None):
        """
        Recursive visitor to calculate complexity weights based on AST nodes.
        Increases depth penalty for nested subqueries.
        """
        score = 0
        
        # If this node is the one we want to exclude (e.g. the CTE definitions block), stop recursion
        if exclude_node and isinstance(node, exclude_node):
            return 0

        # Apply Weight if node type is in our config
        if type(node) in self.weights:
            base_weight = self.weights[type(node)]
            # Formula: Weight + Depth Penalty
            # We add depth to the weight. Deeper logic is heavier.
            score += base_weight + (0.5 * depth)

        # Check for nesting triggers
        # If we enter a Subquery (SELECT inside FROM/WHERE), increment depth
        next_depth = depth
        if isinstance(node, exp.Subquery):
            next_depth += 1
        
        # Recursively visit children
        # sqlglot's args.values() gives us lists of children or single children
        for child_list in node.args.values():
            if isinstance(child_list, list):
                for child in child_list:
                    if isinstance(child, exp.Expression):
                        score += self._compute_structural_score(child, next_depth, exclude_node)
            elif isinstance(child_list, exp.Expression):
                score += self._compute_structural_score(child_list, next_depth, exclude_node)
                
        return score

    def _compute_semantic_penalty(self, root_node):
        """
        Scans the entire tree for aliases and calculates quality ratio.
        """
        aliases = []
        
        # 1. Capture Explicit Aliases (SELECT x AS y)
        for alias_node in root_node.find_all(exp.Alias):
            aliases.append(alias_node.alias)
            
        # 2. Capture Table Aliases (FROM table AS t)
        for table_node in root_node.find_all(exp.Table):
            if table_node.alias:
                aliases.append(table_node.alias)

        if not aliases:
            return 0.0, {"total": 0, "bad": []}

        bad_aliases = []
        for a in aliases:
            is_bad = False
            # Criteria 1: Too short
            if len(a) < self.min_alias_length:
                is_bad = True
            # Criteria 2: Generic
            elif a.lower() in self.generic_aliases:
                is_bad = True
            
            if is_bad:
                bad_aliases.append(a)

        bad_ratio = len(bad_aliases) / len(aliases)
        penalty = self.semantic_weight * bad_ratio
        
        return penalty, {
            "total": len(aliases), 
            "bad_count": len(bad_aliases), 
            "bad_examples": bad_aliases[:5] # Show first 5
        }

In [3]:
# --- Example Usage ---

complex_sql = """
WITH revenue_cte AS (
    SELECT 
        c.id, 
        sum(o.amount) as total_rev
    FROM customers c
    JOIN orders o ON c.id = o.customer_id
    GROUP BY 1
),
risky_users AS (
    SELECT 
        id 
    FROM revenue_cte r
    WHERE r.total_rev > 10000 
      AND (CASE WHEN r.total_rev > 50000 THEN 1 ELSE 0 END) = 1
)
SELECT 
    t1.id, 
    t1.total_rev
FROM revenue_cte t1
LEFT JOIN (
    SELECT user_id, count(*) as c FROM logs GROUP BY 1
) t2 ON t1.id = t2.user_id
WHERE t1.id IN (SELECT id FROM risky_users)
"""

calc = SSCSCalculator()
result = calc.calculate(complex_sql)

print(f"SSCS Score: {result['sscs_score']}")
print("-" * 30)
print(f"Structural Score: {result['structural_score']}")
print(f"Semantic Penalty: {result['semantic_penalty']} (Based on {result['alias_analysis']['bad_count']} bad aliases)")
print("\nBreakdown:")
for item in result['breakdown']:
    print(f" - {item}")
print("\nBad Aliases Found:", result['alias_analysis']['bad_examples'])

SSCS Score: 13.46
------------------------------
Structural Score: 9.5
Semantic Penalty: 0.42 (Based on 5 bad aliases)

Breakdown:
 - CTE 'revenue_cte': 2.0
 - CTE 'risky_users': 3.0
 - Main Query: 4.5

Bad Aliases Found: ['c', 't1', 'c', 'o', 'r']


## Do it again but calculate SSCS for CTEs instead of just structural score

In [4]:
import sqlglot
from sqlglot import exp

class SSCSCalculator:
    def __init__(self):
        # Configuration for weights
        self.weights = {
            exp.Join: 1,
            exp.Where: 1,
            exp.Group: 1,
            exp.Having: 1,
            exp.Order: 1,
            exp.Case: 2,           # Branching logic = higher load
            exp.Window: 2,         # Window functions are complex
            exp.Connector: 1,      # AND / OR
            exp.Subquery: 1        # Base penalty for existence of subquery
        }
        
        # Configuration for Semantic Penalty
        self.semantic_weight = 0.5  # Max penalty (50% increase)
        self.min_alias_length = 3
        self.generic_aliases = {'temp', 'data', 't', 'x', 'val', 'obj', 'row', 'a', 'b', 'c', 't1', 't2'}

    def calculate(self, sql: str):
        try:
            parsed = sqlglot.parse_one(sql)
        except Exception as e:
            return {"error": f"Parse Error: {e}"}

        components = []

        # 1. Analyze CTEs individually
        # We find all CTE definitions. 
        # Note: We assume standard CTEs. Recursive CTEs are treated as normal CTEs here.
        for cte in parsed.find_all(exp.CTE):
            # The 'this' of a CTE is the query inside it (the definition)
            cte_name = cte.alias
            cte_expression = cte.this
            
            comp_result = self._analyze_component(cte_expression, name=f"CTE: {cte_name}")
            components.append(comp_result)

        # 2. Analyze Main Query
        # We need to analyze the main query BUT exclude the WITH clause itself 
        # so we don't double-count the CTEs inside the Main Query score.
        # We create a deep copy or just traverse carefully. 
        # Easier approach: Transform the tree to remove the WITH clause temporarily for analysis.
        main_query_node = parsed.copy()
        if main_query_node.find(exp.With):
             main_query_node.find(exp.With).pop()
             
        comp_result = self._analyze_component(main_query_node, name="Main SELECT")
        components.append(comp_result)

        return {
            "components": components,
            "max_sscs": max(c['sscs'] for c in components) if components else 0
        }

    def _analyze_component(self, node, name):
        """
        Calculates SSCS for a single isolated component (CTE or Main Query).
        """
        # A. Structural Score
        struct_score = self._compute_structural_score(node, depth=0)
        
        # B. Semantic Penalty (Local to this component)
        semantic_penalty, alias_details = self._compute_semantic_penalty(node)
        
        # C. Final Calculation
        sscs = struct_score * (1 + semantic_penalty)
        
        return {
            "name": name,
            "sscs": round(sscs, 2),
            "structural": struct_score,
            "semantic_penalty": round(semantic_penalty, 2),
            "bad_aliases": alias_details['bad_examples']
        }

    def _compute_structural_score(self, node, depth):
        score = 0
        
        # Apply Weight
        if type(node) in self.weights:
            base_weight = self.weights[type(node)]
            score += base_weight + (0.5 * depth)

        # Increment depth ONLY for Subqueries (nested SELECTs), not for just any child
        next_depth = depth
        if isinstance(node, exp.Subquery):
            next_depth += 1
        
        # Recurse
        for child_list in node.args.values():
            if isinstance(child_list, list):
                for child in child_list:
                    if isinstance(child, exp.Expression):
                        score += self._compute_structural_score(child, next_depth)
            elif isinstance(child_list, exp.Expression):
                score += self._compute_structural_score(child_list, next_depth)
                
        return score

    def _compute_semantic_penalty(self, root_node):
        aliases = []
        
        # 1. Capture Explicit Aliases (SELECT x AS y)
        for alias_node in root_node.find_all(exp.Alias):
            aliases.append(alias_node.alias)
            
        # 2. Capture Table Aliases (FROM table AS t)
        for table_node in root_node.find_all(exp.Table):
            if table_node.alias:
                aliases.append(table_node.alias)

        if not aliases:
            return 0.0, {"bad_examples": []}

        bad_aliases = []
        for a in aliases:
            if len(a) < self.min_alias_length or a.lower() in self.generic_aliases:
                bad_aliases.append(a)

        bad_ratio = len(bad_aliases) / len(aliases)
        penalty = self.semantic_weight * bad_ratio
        
        return penalty, {
            "bad_examples": bad_aliases[:5]
        }

In [5]:
# --- Test ---
complex_sql = """
WITH revenue_cte AS (
    SELECT 
        c.id, 
        sum(o.amount) as total_rev
    FROM customers c
    JOIN orders o ON c.id = o.customer_id
    GROUP BY 1
),
risky_users AS (
    SELECT 
        id 
    FROM revenue_cte r
    WHERE r.total_rev > 10000 
      AND (CASE WHEN r.total_rev > 50000 THEN 1 ELSE 0 END) = 1
)
SELECT 
    t1.id, 
    t1.total_rev
FROM revenue_cte t1
LEFT JOIN (
    SELECT user_id, count(*) as c FROM logs GROUP BY 1
) t2 ON t1.id = t2.user_id
WHERE t1.id IN (SELECT id FROM risky_users)
"""

calc = SSCSCalculator()
res = calc.calculate(complex_sql)

for comp in res['components']:
    print(f"{comp['name']} | SSCS: {comp['sscs']} (Struct: {comp['structural']}, Pen: {comp['semantic_penalty']})")
    if comp['bad_aliases']:
        print(f"  Bad Aliases: {comp['bad_aliases']}")

CTE: revenue_cte | SSCS: 2.67 (Struct: 2.0, Pen: 0.33)
  Bad Aliases: ['c', 'o']
CTE: risky_users | SSCS: 4.5 (Struct: 3.0, Pen: 0.5)
  Bad Aliases: ['r']
Main SELECT | SSCS: 6.75 (Struct: 4.5, Pen: 0.5)
  Bad Aliases: ['c', 't1']
