diff --git a/sql_api/api_workflow.py b/sql_api/api_workflow.py index 29d715957d..38f0a4a018 100644 --- a/sql_api/api_workflow.py +++ b/sql_api/api_workflow.py @@ -58,10 +58,14 @@ def post(self, request): serializer = ExecuteCheckSerializer(data=request.data) serializer.is_valid(raise_exception=True) instance = serializer.get_instance() - check_engine = get_engine(instance=instance) - check_result = check_engine.execute_check( - db_name=request.data["db_name"], sql=request.data["full_sql"].strip() - ) + # 交给engine进行检测 + try: + check_engine = get_engine(instance=instance) + check_result = check_engine.execute_check( + db_name=request.data["db_name"], sql=request.data["full_sql"].strip() + ) + except Exception as e: + raise serializers.ValidationError({"errors": f'{e}'}) review_result_list = [] for r in check_result.rows: review_result_list += [r.__dict__] diff --git a/sql_api/tests.py b/sql_api/tests.py index ea949092b5..607092b20b 100644 --- a/sql_api/tests.py +++ b/sql_api/tests.py @@ -1,10 +1,14 @@ from datetime import datetime, timedelta +from unittest.mock import patch + from django.test import TestCase from django.contrib.auth import get_user_model from django.contrib.auth.models import Group, Permission from rest_framework.test import APITestCase from rest_framework import status from common.config import SysConfig +from sql.engines import ReviewSet +from sql.engines.models import ReviewResult from sql.models import ( ResourceGroup, Instance, @@ -451,6 +455,84 @@ def test_get_workflow_log_list(self): self.assertEqual(r.status_code, status.HTTP_200_OK) self.assertEqual(r.json()["count"], 1) + def test_check_param_is_None(self): + """测试工单检测,参数内容为空""" + json_data = { + "full_sql": "", + "db_name": "test_db", + "instance_id": self.ins.id, + } + r = self.client.post("/api/v1/workflow/sqlcheck/", json_data, format="json") + self.assertEqual(r.status_code, status.HTTP_400_BAD_REQUEST) + + @patch("sql_api.api_workflow.get_engine") + def test_check_inception_Exception(self, _get_engine): + """测试工单检测,inception报错""" + json_data = { + "full_sql": "use mysql", + "db_name": "test_db", + "instance_id": self.ins.id, + } + _get_engine.side_effect = RuntimeError("RuntimeError") + r = self.client.post("/api/v1/workflow/sqlcheck/", json_data, format="json") + print(json.loads(r.content)) + self.assertDictEqual( + json.loads(r.content), {'errors': 'RuntimeError'} + ) + + @patch("sql_api.serializers.get_engine") + def test_check(self, _get_engine): + """测试工单检测,正常返回""" + json_data = { + "full_sql": "use mysql", + "db_name": "test_db", + "instance_id": self.ins.id, + } + column_list = [ + "id", + "stage", + "errlevel", + "stagestatus", + "errormessage", + "sql", + "affected_rows", + "sequence", + "backup_dbname", + "execute_time", + "sqlsha1", + "backup_time", + "actual_affected_rows", + ] + + rows = [ + ReviewResult( + id=1, + stage="CHECKED", + errlevel=0, + stagestatus="Audit Completed", + errormessage="", + sql="use `archer`", + affected_rows=0, + actual_affected_rows=0, + sequence="0_0_00000000", + backup_dbname="", + execute_time="0", + sqlsha1="", + ) + ] + _get_engine.return_value.execute_check.return_value = ReviewSet( + warning_count=0, error_count=0, column_list=column_list, rows=rows + ) + r = self.client.post("/api/v1/workflow/sqlcheck/", json_data, format="json") + self.assertListEqual( + list(json.loads(r.content).keys()), + ['is_execute', 'checked', 'warning', 'error', 'warning_count', 'error_count', 'is_critical', 'syntax_type', + 'rows', 'column_list', 'status', 'affected_rows'] + ) + self.assertListEqual( + list(json.loads(r.content)["rows"][0].keys()), column_list + ) + def test_submit_workflow(self): """测试提交SQL上线工单""" json_data = { @@ -468,6 +550,22 @@ def test_submit_workflow(self): self.assertEqual(r.status_code, status.HTTP_201_CREATED) self.assertEqual(r.json()["workflow"]["workflow_name"], "上线工单1") + def test_submit_param_is_None(self): + """测试SQL提交,参数内容为空""" + json_data = { + "workflow": { + "workflow_name": "上线工单1", + "demand_url": "test", + "group_id": 1, + "db_name": "test_db", + "engineer": self.user.username, + "instance": self.ins.id, + }, + "sql_content": "", + } + r = self.client.post("/api/v1/workflow/", json_data, format="json") + self.assertEqual(r.status_code, status.HTTP_400_BAD_REQUEST) + def test_audit_workflow(self): """测试审核工单""" json_data = {