In [None]:
pip uninstall -e . -q

In [None]:
# Prints a fixed greeting; presumably a quick check that the kernel responds.
print("hi")

In [None]:
pip install -e . -q

In [None]:
import os
# from sql_nameguard.llm_suggest import LLMSuggester
from sql_nameguard.analyze import SQLAnalyzer
from sql_nameguard.SSCScalculator import SSCSCalculator
from sql_nameguard.lint import SQLLinter

# TODO: use this when we do a proper llm_suggest example
# from dotenv import load_dotenv
# load_dotenv()

### Basic Testing

Try out the SSCS calculator and the alias linter on an example query for now.

In [None]:
# --- Example Usage ---

# Example input for the calculator and linter cells below.
# Shape of the query: two CTEs (revenue_cte, risky_users), a final SELECT
# that LEFT JOINs an inline aggregate over `logs`, and an IN-subquery filter
# against risky_users. Uses positional GROUP BY and short aliases (c, o, r,
# t1, t2).
complex_sql = """
WITH revenue_cte AS (
    SELECT 
        c.id, 
        sum(o.amount) as total_rev
    FROM customers c
    JOIN orders o ON c.id = o.customer_id
    GROUP BY 1
),
risky_users AS (
    SELECT 
        id 
    FROM revenue_cte r
    WHERE r.total_rev > 10000 
      AND (CASE WHEN r.total_rev > 50000 THEN 1 ELSE 0 END) = 1
)
SELECT 
    t1.id, 
    t1.total_rev
FROM revenue_cte t1
LEFT JOIN (
    SELECT user_id, count(*) as c FROM logs GROUP BY 1
) t2 ON t1.id = t2.user_id
WHERE t1.id IN (SELECT id FROM risky_users)
"""

In [None]:
# Score the example query with warning logging switched off, then print one
# summary block per entry of the returned 'sscs_scores' mapping.
calc = SSCSCalculator()
result = calc.calculate(complex_sql, log_warnings=False)

for name, scores in result['sscs_scores'].items():
    # 'final SELECT' and 'overall' are printed under their own name; every
    # other entry is labelled as a CTE.
    label = name if name in ('final SELECT', 'overall') else f"CTE {name}"
    print(f"{label}:\nSSCS: {scores['SSCS']}\nStructural: {scores['Structural Score']}\nSemantic Penalty: {scores['Semantic Penalty']}\n")

In [None]:
# Re-score the same query with `calc` from the previous cell, this time
# without passing log_warnings, so calculate() falls back to its own default.
# NOTE(review): presumably that default enables warning output -- confirm.
result_2 = calc.calculate(complex_sql)

In [None]:
# Run the alias lint on the example query. It is called on the class itself
# (no SQLLinter instance is created); as the last expression in the cell, its
# return value becomes the cell output.
SQLLinter.lint_aliases(complex_sql)

### Try a hard query

Use a deliberately convoluted query that should definitely be flagged by the SSCS score.

In [None]:
# Stress-test input for the calculator and linter cells below.
# Shape of the query: two CTEs (temp_x, weird_agg) whose output columns and
# table aliases are mostly single letters (a-f, t, q, w, x, y), a correlated
# scalar subquery in temp_x, window functions in weird_agg, a scalar subquery
# inside a CASE in the final SELECT, and a derived table `q` that itself
# LEFT JOINs another inline aggregate over `logs`.
hard_query = """
WITH temp_x AS (
    SELECT
        c.id AS a,
        SUM(o.amount) AS b,
        COUNT(DISTINCT o.id) AS c,
        AVG(o.amount) AS d,
        CASE
            WHEN SUM(o.amount) > 100000 THEN 1
            WHEN SUM(o.amount) BETWEEN 50000 AND 100000 THEN 2
            ELSE 3
        END AS e,
        (
            SELECT COUNT(*)
            FROM orders o2
            WHERE o2.customer_id = c.id
              AND (o2.status = 'PAID' OR o2.status = 'PENDING')
        ) AS f
    FROM customers c
    JOIN orders o ON c.id = o.customer_id
    LEFT JOIN payments p ON o.id = p.order_id
    WHERE
        o.status IN ('PAID', 'PENDING', 'FAILED')
        AND (o.created_at >= '2024-01-01' OR p.processed_at IS NOT NULL)
        AND (o.currency = 'USD' OR o.currency = 'EUR')
    GROUP BY 1
    HAVING COUNT(*) > 5
),
weird_agg AS (
    SELECT
        t.a AS user_key,
        ROW_NUMBER() OVER (PARTITION BY t.e ORDER BY t.d DESC) AS rn,
        SUM(t.b) OVER (PARTITION BY t.e) AS total_b,
        MAX(t.c) OVER () AS max_c,
        MIN(t.d) OVER (PARTITION BY t.e) AS min_d,
        CASE
            WHEN t.f > 50 THEN 1
            WHEN t.f BETWEEN 10 AND 50 THEN 2
            ELSE 3
        END AS z_flag
    FROM temp_x t
    WHERE
        (t.b > 1000 AND t.c > 3)
        OR (t.b > 5000 AND t.d > 200)
)
SELECT
    q.user_key,
    q.total_b,
    q.rn,
    q.max_c,
    q.min_d,
    CASE
        WHEN q.total_b > (
            SELECT AVG(total_b)
            FROM weird_agg wa
            WHERE wa.rn <= 10
              AND (wa.z_flag = 1 OR wa.z_flag = 2)
        )
        THEN 'HIGH'
        ELSE 'LOW'
    END AS risk_bucket
FROM (
    SELECT
        w.user_key,
        w.total_b,
        w.rn,
        w.max_c,
        w.min_d,
        COUNT(*) AS zz
    FROM weird_agg w
    JOIN logs l ON l.user_id = w.user_key
    LEFT JOIN (
        SELECT
            user_id,
            COUNT(*) AS c_log,
            MAX(created_at) AS last_log_at
        FROM logs
        WHERE event_type IN ('login', 'order', 'payment')
        GROUP BY user_id
        HAVING COUNT(*) > 2
    ) x ON x.user_id = w.user_key
    WHERE
        (w.rn <= 100 OR x.c_log > 10 OR x.last_log_at > '2024-06-01')
        AND (l.created_at >= '2024-01-01' AND l.created_at < '2025-01-01')
        AND (l.country = 'US' OR l.country = 'DE' OR l.country = 'FR')
    GROUP BY
        w.user_key,
        w.total_b,
        w.rn,
        w.max_c,
        w.min_d
    HAVING COUNT(*) > 3
) q
LEFT JOIN temp_x y ON y.a = q.user_key
WHERE
    q.rn < 50
    AND (q.zz > 5 OR y.f > 20)
ORDER BY
    q.total_b DESC,
    q.rn ASC;
"""


In [None]:
# Score the hard query with a fresh calculator, this time with warning
# logging switched on.
# Note: this rebinds `calc` and `result`, so the scores of the earlier
# example query are no longer reachable through `result` after this cell.
calc = SSCSCalculator()
result = calc.calculate(hard_query, log_warnings=True)

In [None]:
# Print one summary block per entry of result['sscs_scores'], where `result`
# comes from the preceding cell. 'final SELECT' and 'overall' are printed
# under their own name; every other entry is labelled as a CTE.
for name, scores in result['sscs_scores'].items():
    label = name if name in ('final SELECT', 'overall') else f"CTE {name}"
    print(f"{label}:\nSSCS: {scores['SSCS']}\nStructural: {scores['Structural Score']}\nSemantic Penalty: {scores['Semantic Penalty']}\n")

In [None]:
# Run SQLLinter.lint on the hard query. It is called on the class itself (no
# SQLLinter instance is created); as the last expression in the cell, its
# return value becomes the cell output.
# NOTE(review): the earlier example calls lint_aliases() instead -- presumably
# lint() is the broader entry point; confirm against sql_nameguard.lint.
SQLLinter.lint(hard_query)