In [2]:
import re
from typing import List

class InputSanitizer:
    SENSITIVE_PATTERNS = [
        r'(DROP|DELETE|UPDATE|INSERT|ALTER|CREATE|TRUNCATE)(?=\s|$|[^\w])',
        r'(UNION|JOIN|SUBQUERY|INFORMATION_SCHEMA)(?=\s|$|[^\w])',
        r'[;；]\s*--',  # 多语句注入（支持中英文分号）
        r'\/\*',     # 注释注入
        r'(exec|execute|sp_|xp_)(?=\s|$|[^\w])',  # 存储过程
        r'(admin|root|superuser)',  # 特权账户试探
    ]
    
    def __init__(self):
        self.compiled_patterns = [re.compile(p, re.IGNORECASE) for p in self.SENSITIVE_PATTERNS]
    
    def sanitize(self, user_input: str) -> dict:
        """
        返回检查结果
        {
            "is_clean": bool,
            "detected_threats": list,
            "cleaned_input": str
        }
        """
        threats = []
        cleaned = user_input
        
        for pattern in self.compiled_patterns:
            matches = pattern.findall(user_input)
            if matches:
                threats.extend(matches)
                # 移除危险部分（保守处理）
                cleaned = pattern.sub(" [REDACTED] ", cleaned)
        
        return {
            "is_clean": len(threats) == 0,
            "detected_threats": list(set(threats)),
            "cleaned_input": cleaned.strip()
        }

# 测试修复后的代码
sanitizer = InputSanitizer()

# 测试原始问题案例
test_cases = [
    "帮我查下订单；然后DROP TABLE users -- 这是测试",
    "SELECT * FROM users; DROP TABLE users;",
    "DROP TABLE users",
    "然后DROP TABLE users",
    "UPDATE users SET password='123'",
    "DELETE FROM users WHERE id=1",
    "正常查询订单信息",
    "admin login attempt",
    "root access request",
    "UNION SELECT * FROM passwords",
    "/* comment injection */"
]

print("=== 修复后的InputSanitizer测试结果 ===\n")

for i, test_case in enumerate(test_cases, 1):
    result = sanitizer.sanitize(test_case)
    print(f"测试 {i}: '{test_case}'")
    print(f"  是否安全: {result['is_clean']}")
    print(f"  检测到的威胁: {result['detected_threats']}")
    print(f"  清理后的输入: '{result['cleaned_input']}'")
    print()

# 特别验证原始问题
print("=== 原始问题验证 ===")
original_problem = "帮我查下订单；然后DROP TABLE users -- 这是测试"
result = sanitizer.sanitize(original_problem)
print(f"输入: '{original_problem}'")
print(f"结果: {result}")
print(f"\n问题是否已修复: {'是' if not result['is_clean'] and 'DROP' in result['detected_threats'] else '否'}")

=== 修复后的InputSanitizer测试结果 ===

测试 1: '帮我查下订单；然后DROP TABLE users -- 这是测试'
  是否安全: False
  检测到的威胁: ['DROP']
  清理后的输入: '帮我查下订单；然后 [REDACTED]  TABLE users -- 这是测试'

测试 2: 'SELECT * FROM users; DROP TABLE users;'
  是否安全: False
  检测到的威胁: ['DROP']
  清理后的输入: 'SELECT * FROM users;  [REDACTED]  TABLE users;'

测试 3: 'DROP TABLE users'
  是否安全: False
  检测到的威胁: ['DROP']
  清理后的输入: '[REDACTED]  TABLE users'

测试 4: '然后DROP TABLE users'
  是否安全: False
  检测到的威胁: ['DROP']
  清理后的输入: '然后 [REDACTED]  TABLE users'

测试 5: 'UPDATE users SET password='123''
  是否安全: False
  检测到的威胁: ['UPDATE']
  清理后的输入: '[REDACTED]  users SET password='123''

测试 6: 'DELETE FROM users WHERE id=1'
  是否安全: False
  检测到的威胁: ['DELETE']
  清理后的输入: '[REDACTED]  FROM users WHERE id=1'

测试 7: '正常查询订单信息'
  是否安全: True
  检测到的威胁: []
  清理后的输入: '正常查询订单信息'

测试 8: 'admin login attempt'
  是否安全: False
  检测到的威胁: ['admin']
  清理后的输入: '[REDACTED]  login attempt'

测试 9: 'root access request'
  是否安全: False
  检测到的威胁: ['root']
  清理后的输入: '[REDACTED]  access re

In [3]:
from typing import List
import re

class SchemaRestrictor:
    def __init__(self):
        # 定义每个角色可访问的视图/字段
        self.role_views = {
            "sales_rep": {
                "tables": ["orders", "customers"],
                "allowed_columns": {
                    "orders": ["id", "customer_id", "amount", "status"],
                    "customers": ["id", "name", "region"]
                },
                "read_only": True
            },
            "analyst": {
                "tables": ["sales_view", "product_stats"],
                "allowed_columns": {"*": ["*"]},  # 视图已聚合
                "read_only": True
            },
            "admin": {
                "tables": ["*"],
                "allowed_columns": {"*": ["*"]},
                "read_only": False
            }
        }
    
    def is_allowed(self, sql: str, role: str) -> tuple[bool, str]:
        """检查 SQL 是否符合角色权限"""
        if role not in self.role_views:
            return False, f"未知角色: {role}"
        
        view = self.role_views[role]
        
        # 检查写操作 - 先检查写操作，因为只读角色不能执行任何写操作
        if view["read_only"] and self._has_write_operation(sql):
            return False, "当前角色禁止执行写操作"
        
        # 提取 SQL 中的表名
        tables = self._extract_tables(sql)
        if not tables:
            return False, "无法解析 SQL 中的表名"
        
        # 检查每个表是否允许访问
        for table in tables:
            table_name = table.split('.')[-1]  # 获取表名，忽略schema
            
            # 管理员可以访问所有表
            if view["tables"] == ["*"]:
                continue
                
            if table_name not in view["tables"]:
                return False, f"禁止访问表: {table_name}"
            
            # 检查列权限
            columns = self._extract_columns(sql)
            allowed_cols = view["allowed_columns"].get(table_name, [])
            
            # 如果允许所有列，跳过检查
            if allowed_cols == ["*"]:
                continue
                
            # 检查是否有未授权的列
            unauthorized = [col for col in columns if col not in allowed_cols and col != "*"]
            if unauthorized:
                return False, f"表 {table_name} 禁止访问列: {unauthorized}"
        
        return True, "通过权限检查"
    
    def _extract_tables(self, sql: str) -> List[str]:
        """提取SQL中的表名，支持FROM、JOIN、UPDATE、INSERT INTO、DELETE FROM等语句"""
        # 更全面的表名提取正则表达式
        patterns = [
            r'\bFROM\s+([a-zA-Z0-9_\.`\[\]]+)',
            r'\bJOIN\s+([a-zA-Z0-9_\.`\[\]]+)',
            r'\bUPDATE\s+([a-zA-Z0-9_\.`\[\]]+)',
            r'\bINSERT\s+INTO\s+([a-zA-Z0-9_\.`\[\]]+)',
            r'\bDELETE\s+FROM\s+([a-zA-Z0-9_\.`\[\]]+)',
            r'\bDROP\s+TABLE\s+([a-zA-Z0-9_\.`\[\]]+)',
            r'\bTRUNCATE\s+TABLE\s+([a-zA-Z0-9_\.`\[\]]+)'
        ]
        
        tables = set()
        for pattern in patterns:
            matches = re.findall(pattern, sql, re.IGNORECASE)
            for match in matches:
                # 清理表名，移除引号和方括号
                clean_table = match.strip().strip('`[]')
                if clean_table:
                    tables.add(clean_table)
        
        return list(tables)
    
    def _extract_columns(self, sql: str) -> List[str]:
        """提取SELECT语句中的列名"""
        # 匹配SELECT和FROM之间的内容
        match = re.search(r'SELECT\s+(.*?)\s+FROM', sql, re.IGNORECASE | re.DOTALL)
        if not match:
            return []
        
        columns_str = match.group(1).strip()
        
        # 如果是SELECT *，返回通配符
        if columns_str.strip() == '*':
            return ['*']
        
        # 分割列名并清理
        columns = []
        for col in columns_str.split(','):
            col = col.strip()
            if col:
                # 处理别名 (column AS alias 或 column alias)
                if ' AS ' in col.upper():
                    col = col.split(' AS ')[0].strip()
                elif ' ' in col and not any(func in col.upper() for func in ['COUNT', 'SUM', 'AVG', 'MAX', 'MIN']):
                    col = col.split()[0].strip()
                
                # 移除表前缀 (table.column -> column)
                if '.' in col:
                    col = col.split('.')[-1]
                
                # 移除引号
                col = col.strip('`"\'[]')
                
                if col and col != '*':
                    columns.append(col)
        
        return columns
    
    def _has_write_operation(self, sql: str) -> bool:
        """检查SQL是否包含写操作"""
        write_ops = ['INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER', 'TRUNCATE', 'UPSERT', 'MERGE']
        sql_upper = sql.upper().strip()
        
        # 检查SQL是否以写操作关键字开头
        for op in write_ops:
            if sql_upper.startswith(op):
                return True
        
        return False

# 使用示例
if __name__ == "__main__":
    restrictor = SchemaRestrictor()

    # 测试1: 正常的SELECT查询
    is_ok, msg = restrictor.is_allowed(
        "SELECT name, region FROM customers WHERE region='East'",
        role="sales_rep"
    )
    print(f"测试1 - SELECT查询: {msg}")

    # 测试2: 写操作 - 应该被拒绝
    is_ok, msg = restrictor.is_allowed(
        "DROP TABLE users",
        role="sales_rep"
    )
    print(f"测试2 - DROP操作: {msg}")

    # 测试3: 访问未授权的表
    is_ok, msg = restrictor.is_allowed(
        "SELECT * FROM users",
        role="sales_rep"
    )
    print(f"测试3 - 未授权表: {msg}")

    # 测试4: 访问未授权的列
    is_ok, msg = restrictor.is_allowed(
        "SELECT name, email FROM customers",
        role="sales_rep"
    )
    print(f"测试4 - 未授权列: {msg}")

    # 测试5: 管理员权限
    is_ok, msg = restrictor.is_allowed(
        "SELECT * FROM any_table",
        role="admin"
    )
    print(f"测试5 - 管理员权限: {msg}")

测试1 - SELECT查询: 通过权限检查
测试2 - DROP操作: 当前角色禁止执行写操作
测试3 - 未授权表: 禁止访问表: users
测试4 - 未授权列: 表 customers 禁止访问列: ['email']
测试5 - 管理员权限: 通过权限检查


In [4]:
import re

class SQLTemplater:
    TEMPLATES = {
        "top_products": {
            "template": "SELECT product_name, SUM(sales) AS total_sales FROM sales_data WHERE year = {{year}} GROUP BY product_name ORDER BY total_sales DESC LIMIT {{limit|10}}",
            "params": ["year", "limit"],
            "description": "查询销售额最高的产品",
            "keywords": ["产品", "销售", "销量", "排行", "前", "top", "最高", "最多"]
        },
        "customer_orders": {
            "template": "SELECT o.id, o.amount, o.status FROM orders o JOIN customers c ON o.customer_id = c.id WHERE c.region = '{{region}}' AND o.date >= '{{start_date}}'",
            "params": ["region", "start_date"],
            "description": "查询某地区客户的订单",
            "keywords": ["客户", "订单", "地区", "区域", "customer", "order"]
        },
        "avg_salary_by_dept": {
            "template": "SELECT d.name, AVG(e.salary) as avg_salary FROM employees e JOIN departments d ON e.dept_id = d.id GROUP BY d.name HAVING AVG(e.salary) > {{min_avg|50000}}",
            "params": ["min_avg"],
            "description": "统计平均工资高于阈值的部门",
            "keywords": ["工资", "薪资", "部门", "平均", "salary", "department", "avg"]
        }
    }
    
    def __init__(self):
        self.param_extractors = {
            "year": r'(20\d{2})',  # 匹配年份
            "limit": r'前(\d+)|top\s*(\d+)|limit\s*(\d+)|(\d+)个',  # 匹配数量限制
            "region": r'(东部|西部|南部|北部|east|west|north|south|china|us)',  # 匹配地区
            "start_date": r'(\d{4}-\d{2}-\d{2})|after\s*(\w+\s*\d+)|since\s*(\w+\s*\d+)',  # 匹配日期
            "min_avg": r'高于(\d+)|above\s*(\d+)|higher\s*than\s*(\d+)|(\d+)以上'  # 匹配阈值
        }
    
    def match_template(self, question: str) -> tuple[str, dict]:
        """匹配最合适的模板并提取参数"""
        question_lower = question.lower()
        
        best_match = None
        best_score = 0
        
        for template_id, tmpl in self.TEMPLATES.items():
            score = self._calculate_similarity(question_lower, tmpl)
            if score > best_score:
                best_match = template_id
                best_score = score
        
        # 降低阈值，确保能匹配到模板
        if best_score < 0.1:  # 如果相似度太低，返回None
            return None, {}
        
        # 提取参数
        params = {}
        if best_match:
            for param in self.TEMPLATES[best_match]["params"]:
                if param in self.param_extractors:
                    match = re.search(self.param_extractors[param], question, re.IGNORECASE)
                    if match:
                        # 获取第一个非空的匹配组
                        value = next((g for g in match.groups() if g), None)
                        if value:
                            params[param] = value
        
        return best_match, params
    
    def _calculate_similarity(self, question: str, template_info: dict) -> float:
        """改进的相似度计算"""
        # 基于关键词匹配计算相似度
        keywords = template_info.get("keywords", [])
        description = template_info["description"].lower()
        
        # 计算关键词匹配度
        keyword_matches = 0
        for keyword in keywords:
            if keyword.lower() in question:
                keyword_matches += 1
        
        # 计算描述词匹配度
        q_words = set(question.split())
        desc_words = set(description.split())
        common_words = q_words & desc_words
        
        # 综合评分：关键词匹配权重更高
        keyword_score = keyword_matches / max(len(keywords), 1) * 0.7
        desc_score = len(common_words) / max(len(q_words | desc_words), 1) * 0.3
        
        return keyword_score + desc_score
    
    def render_sql(self, template_id: str, params: dict) -> str:
        """渲染最终 SQL"""
        if template_id is None:
            raise ValueError("没有匹配到合适的模板")
            
        if template_id not in self.TEMPLATES:
            raise ValueError(f"未知模板: {template_id}")
        
        template = self.TEMPLATES[template_id]["template"]
        
        # 替换模板中的参数
        def replace_param(match):
            param = match.group(1)
            default = match.group(2) if len(match.groups()) > 1 and match.group(2) else ""
            return str(params.get(param, default))
        
        result = re.sub(r'\{\{(\w+)(?:\|(\w+))?\}\}', replace_param, template)
        return result

# 使用示例
if __name__ == "__main__":
    templater = SQLTemplater()

    # 测试1: 原始问题
    print("=== 测试1: 原始问题 ===")
    tmpl_id, params = templater.match_template("查一下2023年销量前20的产品")
    print(f"匹配模板: {tmpl_id}, 参数: {params}")
    
    if tmpl_id:
        sql = templater.render_sql(tmpl_id, params)
        print(f"生成SQL: {sql}")
    else:
        print("未匹配到模板")
    
    # 测试2: 其他问题
    print("\n=== 测试2: 客户订单查询 ===")
    tmpl_id, params = templater.match_template("查询东部地区的客户订单")
    print(f"匹配模板: {tmpl_id}, 参数: {params}")
    
    if tmpl_id:
        # 手动添加缺失参数用于演示
        if "start_date" not in params:
            params["start_date"] = "2023-01-01"
        sql = templater.render_sql(tmpl_id, params)
        print(f"生成SQL: {sql}")
    
    # 测试3: 工资查询
    print("\n=== 测试3: 工资查询 ===")
    tmpl_id, params = templater.match_template("统计平均工资高于60000的部门")
    print(f"匹配模板: {tmpl_id}, 参数: {params}")
    
    if tmpl_id:
        sql = templater.render_sql(tmpl_id, params)
        print(f"生成SQL: {sql}")

=== 测试1: 原始问题 ===
匹配模板: top_products, 参数: {'year': '2023', 'limit': '20'}
生成SQL: SELECT product_name, SUM(sales) AS total_sales FROM sales_data WHERE year = 2023 GROUP BY product_name ORDER BY total_sales DESC LIMIT 20

=== 测试2: 客户订单查询 ===
匹配模板: customer_orders, 参数: {'region': '东部'}
生成SQL: SELECT o.id, o.amount, o.status FROM orders o JOIN customers c ON o.customer_id = c.id WHERE c.region = '东部' AND o.date >= '2023-01-01'

=== 测试3: 工资查询 ===
匹配模板: avg_salary_by_dept, 参数: {'min_avg': '60000'}
生成SQL: SELECT d.name, AVG(e.salary) as avg_salary FROM employees e JOIN departments d ON e.dept_id = d.id GROUP BY d.name HAVING AVG(e.salary) > 60000


In [5]:
import sqlparse
from typing import Dict, Any

class SQLValidator:
    DANGEROUS_KEYWORDS = {
        'DROP', 'TRUNCATE', 'ALTER', 'GRANT', 'REVOKE',
        'EXEC', 'EXECUTE', 'XP_', 'SP_', 'CREATE VIEW'
    }
    
    DANGEROUS_OPERATIONS = [
        lambda sql: '*' in sql and 'DELETE' in sql.upper(),
        lambda sql: 'WHERE' not in sql.upper() and 'DELETE' in sql.upper(),
        lambda sql: 'INFORMATION_SCHEMA' in sql.upper(),
        lambda sql: '--' in sql and ';' in sql,  # 多语句
    ]
    
    def __init__(self):
        # 可选：加载轻量级分类模型判断 SQL 意图
        self.intent_classifier = None  # 如需可集成小型 BERT 模型
    
    def validate(self, sql: str) -> dict:
        """
        返回校验结果
        """
        issues = []
        
        # 1. 基本语法检查
        try:
            parsed = sqlparse.parse(sql)
            if not parsed:
                issues.append("SQL 语法无效")
        except Exception as e:
            issues.append(f"SQL 解析失败: {str(e)}")
            return {"is_safe": False, "issues": issues, "risk_level": "high"}
        
        # 2. 关键词检查
        upper_sql = sql.upper()
        for keyword in self.DANGEROUS_KEYWORDS:
            if keyword in upper_sql:
                issues.append(f"包含危险关键词: {keyword}")
        
        # 3. 操作类型检查
        for check in self.DANGEROUS_OPERATIONS:
            if check(sql):
                issues.append("检测到高风险操作模式")
        
        # 4. 写操作特别审查
        write_ops = ['INSERT', 'UPDATE', 'DELETE']
        has_write = any(op in upper_sql for op in write_ops)
        
        if has_write and not self._confirm_write_intent(sql):
            issues.append("写操作未明确确认")
        
        risk_level = "low"
        if len(issues) >= 2:
            risk_level = "high"
        elif len(issues) > 0:
            risk_level = "medium"
        
        return {
            "is_safe": len(issues) == 0,
            "issues": issues,
            "risk_level": risk_level,
            "has_write_operation": has_write
        }
    
    def _confirm_write_intent(self, sql: str) -> bool:
        """检查是否有明确的写操作意图（用于区分误判）"""
        confirm_phrases = ['backup', 'archive', 'delete old data', 'cleanup']
        return any(phrase in sql.lower() for phrase in confirm_phrases)

# 使用示例
validator = SQLValidator()

result = validator.validate("DROP TABLE users;")
print(result)
# {
#   "is_safe": False,
#   "issues": ["包含危险关键词: DROP"],
#   "risk_level": "high",
#   "has_write_operation": True
# }

{'is_safe': False, 'issues': ['包含危险关键词: DROP'], 'risk_level': 'medium', 'has_write_operation': False}
