From 5f9f7d71e66c0a4b157dc50a461d0ac797fe842b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E5=9C=88=E5=9C=88?= Date: Sat, 1 Apr 2023 21:32:04 +0800 Subject: [PATCH] =?UTF-8?q?Revert=20"engine=E5=A2=9E=E5=8A=A0escape=5Fstri?= =?UTF-8?q?ng=E7=94=A8=E4=BA=8E=E5=A4=84=E7=90=86=E5=AD=97=E7=AC=A6?= =?UTF-8?q?=E4=B8=B2=E5=8F=82=E6=95=B0=E8=BD=AC=E4=B9=89=20(#2107)"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit e2573b74a1017912815c8e26f6d3c843c7fe2dae. --- sql/data_dictionary.py | 8 +++----- sql/engines/__init__.py | 4 ---- sql/engines/clickhouse.py | 5 ----- sql/engines/mysql.py | 10 +++------- sql/instance.py | 15 +++++---------- sql/instance_database.py | 4 +++- sql/sql_optimize.py | 4 ++-- sql/sql_tuning.py | 2 +- sql/tests.py | 2 +- sql_api/api_instance.py | 8 +++++--- sql_api/api_workflow.py | 5 +---- 11 files changed, 24 insertions(+), 43 deletions(-) diff --git a/sql/data_dictionary.py b/sql/data_dictionary.py index 4a702d8075..2ee48b8d5e 100644 --- a/sql/data_dictionary.py +++ b/sql/data_dictionary.py @@ -29,7 +29,6 @@ def table_list(request): instance_name=instance_name, db_type=db_type ) query_engine = get_engine(instance=instance) - db_name = query_engine.escape_string(db_name) data = query_engine.get_group_tables_by_db(db_name=db_name) res = {"status": 0, "data": data} except Instance.DoesNotExist: @@ -51,7 +50,6 @@ def table_info(request): db_name = request.GET.get("db_name", "") tb_name = request.GET.get("tb_name", "") db_type = request.GET.get("db_type", "") - if instance_name and db_name and tb_name: data = {} try: @@ -59,8 +57,6 @@ def table_info(request): instance_name=instance_name, db_type=db_type ) query_engine = get_engine(instance=instance) - db_name = query_engine.escape_string(db_name) - tb_name = query_engine.escape_string(tb_name) data["meta_data"] = query_engine.get_table_meta_data( db_name=db_name, tb_name=tb_name ) @@ -95,6 +91,8 @@ def export(request): """导出数据字典""" instance_name = request.GET.get("instance_name", "") db_name = request.GET.get("db_name", "") + # escape + db_name = MySQLdb.escape_string(db_name).decode("utf-8") try: instance = user_instances( @@ -106,7 +104,7 @@ def export(request): # 普通用户仅可以获取指定数据库的字典信息 if db_name: - dbs = [query_engine.escape_string(db_name)] + dbs = [db_name] # 管理员可以导出整个实例的字典信息 elif request.user.is_superuser: dbs = query_engine.get_all_databases().rows diff --git a/sql/engines/__init__.py b/sql/engines/__init__.py index 7adf3df930..a101abed13 100644 --- a/sql/engines/__init__.py +++ b/sql/engines/__init__.py @@ -86,10 +86,6 @@ def info(self): """返回引擎简介""" return "Base engine" - def escape_string(self, value: str) -> str: - """参数转义""" - return value - @property def auto_backup(self): """是否支持备份""" diff --git a/sql/engines/clickhouse.py b/sql/engines/clickhouse.py index a776d7ed07..22216c4be5 100644 --- a/sql/engines/clickhouse.py +++ b/sql/engines/clickhouse.py @@ -1,6 +1,5 @@ # -*- coding: UTF-8 -*- from clickhouse_driver import connect -from clickhouse_driver.util.escape import escape_chars_map from sql.utils.sql_utils import get_syntax_type from .models import ResultSet, ReviewResult, ReviewSet from common.utils.timer import FuncTimer @@ -50,10 +49,6 @@ def name(self): def info(self): return "ClickHouse engine" - def escape_string(self, value: str) -> str: - """字符串参数转义""" - return "'%s'" % "".join(escape_chars_map.get(c, c) for c in value) - @property def auto_backup(self): """是否支持备份""" diff --git a/sql/engines/mysql.py b/sql/engines/mysql.py index 637a604b71..ee47fd4240 100644 --- a/sql/engines/mysql.py +++ b/sql/engines/mysql.py @@ -98,10 +98,6 @@ def name(self): def info(self): return "MySQL engine" - def escape_string(self, value: str) -> str: - """字符串参数转义""" - return MySQLdb.escape_string(value).decode("utf-8") - @property def auto_backup(self): """是否支持备份""" @@ -171,7 +167,7 @@ def get_all_tables(self, db_name, **kwargs): def get_group_tables_by_db(self, db_name): # escape - db_name = self.escape_string(db_name) + db_name = MySQLdb.escape_string(db_name).decode("utf-8") data = {} sql = f"""SELECT TABLE_NAME, TABLE_COMMENT @@ -190,8 +186,8 @@ def get_group_tables_by_db(self, db_name): def get_table_meta_data(self, db_name, tb_name, **kwargs): """数据字典页面使用:获取表格的元信息,返回一个dict{column_list: [], rows: []}""" # escape - db_name = self.escape_string(db_name) - tb_name = self.escape_string(tb_name) + db_name = MySQLdb.escape_string(db_name).decode("utf-8") + tb_name = MySQLdb.escape_string(tb_name).decode("utf-8") sql = f"""SELECT TABLE_NAME as table_name, ENGINE as engine, diff --git a/sql/instance.py b/sql/instance.py index 34accdc3c3..c041f7035f 100644 --- a/sql/instance.py +++ b/sql/instance.py @@ -163,9 +163,6 @@ def param_edit(request): instance_id = request.POST.get("instance_id") variable_name = request.POST.get("variable_name") variable_value = request.POST.get("runtime_value") - # escape - variable_name = MySQLdb.escape_string(variable_name).decode("utf-8") - variable_value = MySQLdb.escape_string(variable_value).decode("utf-8") try: ins = Instance.objects.get(id=instance_id) @@ -323,10 +320,12 @@ def instance_resource(request): result = {"status": 0, "msg": "ok", "data": []} try: + # escape + db_name = MySQLdb.escape_string(db_name).decode("utf-8") + schema_name = MySQLdb.escape_string(schema_name).decode("utf-8") + tb_name = MySQLdb.escape_string(tb_name).decode("utf-8") + query_engine = get_engine(instance=instance) - db_name = query_engine.escape_string(db_name) - schema_name = query_engine.escape_string(schema_name) - tb_name = query_engine.escape_string(tb_name) if resource_type == "database": resource = query_engine.get_all_databases() elif resource_type == "schema" and db_name: @@ -364,14 +363,10 @@ def describe(request): db_name = request.POST.get("db_name") schema_name = request.POST.get("schema_name") tb_name = request.POST.get("tb_name") - result = {"status": 0, "msg": "ok", "data": []} try: query_engine = get_engine(instance=instance) - db_name = query_engine.escape_string(db_name) - schema_name = query_engine.escape_string(schema_name) - tb_name = query_engine.escape_string(tb_name) query_result = query_engine.describe_table( db_name, tb_name, schema_name=schema_name ) diff --git a/sql/instance_database.py b/sql/instance_database.py index 87b51fcb4a..15d1572c35 100644 --- a/sql/instance_database.py +++ b/sql/instance_database.py @@ -111,8 +111,10 @@ def create(request): except Users.DoesNotExist: return JsonResponse({"status": 1, "msg": "负责人不存在", "data": []}) + # escape + db_name = MySQLdb.escape_string(db_name).decode("utf-8") + engine = get_engine(instance=instance) - db_name = engine.escape_string(db_name) exec_result = engine.execute( db_name="information_schema", sql=f"create database {db_name};" ) diff --git a/sql/sql_optimize.py b/sql/sql_optimize.py index 62a9f8a80b..b147fa9c98 100644 --- a/sql/sql_optimize.py +++ b/sql/sql_optimize.py @@ -163,6 +163,8 @@ def optimize_sqltuning(request): except Instance.DoesNotExist: result = {"status": 1, "msg": "你所在组未关联该实例!", "data": []} return HttpResponse(json.dumps(result), content_type="application/json") + # escape + db_name = MySQLdb.escape_string(db_name).decode("utf-8") sql_tunning = SqlTuning( instance_name=instance_name, db_name=db_name, sqltext=sqltext @@ -233,7 +235,6 @@ def explain(request): # 执行获取执行计划语句 query_engine = get_engine(instance=instance) - db_name = query_engine.escape_string(db_name) sql_result = query_engine.query(str(db_name), sql_content).to_sep_dict() result["data"] = sql_result @@ -286,7 +287,6 @@ def optimize_sqltuningadvisor(request): # 执行获取优化报告 query_engine = get_engine(instance=instance) - db_name = query_engine.escape_string(db_name) sql_result = query_engine.sqltuningadvisor(str(db_name), sql_content).to_sep_dict() result["data"] = sql_result diff --git a/sql/sql_tuning.py b/sql/sql_tuning.py index 4dac2a46cf..973406cba9 100644 --- a/sql/sql_tuning.py +++ b/sql/sql_tuning.py @@ -13,7 +13,7 @@ def __init__(self, instance_name, db_name, sqltext): instance = Instance.objects.get(instance_name=instance_name) query_engine = get_engine(instance=instance) self.engine = query_engine - self.db_name = self.engine.escape_string(db_name) + self.db_name = db_name self.sqltext = sqltext self.sql_variable = """ select diff --git a/sql/tests.py b/sql/tests.py index 357a1ccea0..5a3e8202c9 100644 --- a/sql/tests.py +++ b/sql/tests.py @@ -2539,7 +2539,7 @@ def test_param_edit_variable_not_config( data = { "instance_id": self.master.id, "variable_name": "1", - "runtime_value": "false", + "variable_value": "false", } r = self.client.post(path="/param/edit/", data=data) self.assertEqual( diff --git a/sql_api/api_instance.py b/sql_api/api_instance.py index 6787ca4149..4cb50b51ba 100644 --- a/sql_api/api_instance.py +++ b/sql_api/api_instance.py @@ -187,10 +187,12 @@ def post(self, request): instance = Instance.objects.get(pk=instance_id) try: + # escape + db_name = MySQLdb.escape_string(db_name).decode("utf-8") + schema_name = MySQLdb.escape_string(schema_name).decode("utf-8") + tb_name = MySQLdb.escape_string(tb_name).decode("utf-8") + query_engine = get_engine(instance=instance) - db_name = query_engine.escape_string(db_name) - schema_name = query_engine.escape_string(schema_name) - tb_name = query_engine.escape_string(tb_name) if resource_type == "database": resource = query_engine.get_all_databases() elif resource_type == "schema" and db_name: diff --git a/sql_api/api_workflow.py b/sql_api/api_workflow.py index 480a704e1a..fa702cefb5 100644 --- a/sql_api/api_workflow.py +++ b/sql_api/api_workflow.py @@ -1,4 +1,3 @@ -import MySQLdb from django.contrib.auth.decorators import permission_required from django.utils.decorators import method_decorator from rest_framework import views, generics, status, serializers, permissions @@ -61,11 +60,9 @@ def post(self, request): instance = serializer.get_instance() # 交给engine进行检测 try: - db_name = request.data["db_name"] check_engine = get_engine(instance=instance) - db_name = check_engine.escape_string(db_name) check_result = check_engine.execute_check( - db_name=db_name, sql=request.data["full_sql"].strip() + db_name=request.data["db_name"], sql=request.data["full_sql"].strip() ) except Exception as e: raise serializers.ValidationError({"errors": f"{e}"})