diff --git a/common/tests.py b/common/tests.py index 2a7a2881e9..bbf139fe53 100644 --- a/common/tests.py +++ b/common/tests.py @@ -337,6 +337,9 @@ def setUpClass(cls): cls.superuser1 = User(username='super1', is_superuser=True) cls.superuser1.save() cls.now = datetime.datetime.now() + cls.slave1 = Instance(instance_name='test_slave_instance', type='slave', db_type='mysql', + host='testhost', port=3306, user='mysql_user', password='mysql_password') + cls.slave1.save() # 批量创建数据 ddl ,u1 ,g1, yesterday 组, 2 个数据 ddl_workflow = [SqlWorkflow( workflow_name='ddl %s' % i, @@ -346,8 +349,9 @@ def setUpClass(cls): engineer_display=cls.u1.display, audit_auth_groups='some_group', create_time=cls.now - datetime.timedelta(days=1), - status = '已正常结束', - is_backup = '是', + status='workflow_finish', + is_backup='是', + instance=cls.slave1, instance_name='some_instance', db_name='some_db', sql_content='some_sql', @@ -362,8 +366,9 @@ def setUpClass(cls): engineer_display=cls.u2.display, audit_auth_groups='some_group', create_time=cls.now - datetime.timedelta(days=2), - status='已正常结束', + status='workflow_finish', is_backup='是', + instance=cls.slave1, instance_name='some_instance', db_name='some_db', sql_content='some_sql', @@ -375,6 +380,15 @@ def setUpClass(cls): # #) for i in range(20)] + @classmethod + def tearDownClass(cls): + SqlWorkflow.objects.all().delete() + QueryLog.objects.all().delete() + cls.u1.delete() + cls.u2.delete() + cls.superuser1.delete() + cls.slave1.delete() + def testGetDateList(self): dao = ChartDao() end = datetime.date.today() @@ -420,14 +434,6 @@ def testDashboard(self): r = c.get('/dashboard/') self.assertEqual(r.status_code, 200) - @classmethod - def tearDownClass(cls): - SqlWorkflow.objects.all().delete() - QueryLog.objects.all().delete() - cls.u1.delete() - cls.u2.delete() - cls.superuser1.delete() - class AuthTest(TestCase): diff --git a/sql/admin.py b/sql/admin.py index d7b137738b..c5b2eb4bc7 100644 --- a/sql/admin.py +++ b/sql/admin.py @@ -73,7 +73,8 @@ class QueryLogAdmin(admin.ModelAdmin): @admin.register(QueryPrivileges) class QueryPrivilegesAdmin(admin.ModelAdmin): list_display = ( - 'user_display', 'instance_name', 'db_name', 'table_name', 'valid_date', 'limit_num', 'create_time') + # TODO 删除instance_name + 'user_display', 'instance_name', 'instance', 'db_name', 'table_name', 'valid_date', 'limit_num', 'create_time') search_fields = ['user_display', 'instance_name'] list_filter = ('user_display', 'instance_name', 'db_name', 'table_name',) diff --git a/sql/models.py b/sql/models.py index 0524c4cd75..0670750d87 100644 --- a/sql/models.py +++ b/sql/models.py @@ -122,6 +122,7 @@ class SqlWorkflow(models.Model): status = models.CharField(max_length=50, choices=SQL_WORKFLOW_CHOICES) is_backup = models.CharField('是否备份', choices=(('否', '否'), ('是', '是')), max_length=20) review_content = models.TextField('自动审核内容的JSON格式') + # TODO 需要删除instance_name 字段 instance = models.ForeignKey(Instance, on_delete=models.CASCADE) instance_name = models.CharField('实例名称', max_length=50) db_name = models.CharField('数据库', max_length=60) @@ -244,6 +245,8 @@ class QueryPrivilegesApply(models.Model): title = models.CharField('申请标题', max_length=50) user_name = models.CharField('申请人', max_length=30) user_display = models.CharField('申请人中文名', max_length=50, default='') + # TODO 后期删除 instance_name + instance = models.ForeignKey(Instance, on_delete=models.CASCADE) instance_name = models.CharField('实例名称', max_length=50) db_list = models.TextField('数据库') # 逗号分隔的数据库列表 table_list = models.TextField('表') # 逗号分隔的表列表 @@ -270,6 +273,7 @@ class QueryPrivileges(models.Model): privilege_id = models.AutoField(primary_key=True) user_name = models.CharField('用户名', max_length=30) user_display = models.CharField('申请人中文名', max_length=50, default='') + # TODO 后期删除instance_name instance = models.ForeignKey(Instance, on_delete=models.CASCADE) instance_name = models.CharField('实例名称', max_length=50) db_name = models.CharField('数据库', max_length=200) @@ -293,7 +297,6 @@ class Meta: # 记录在线查询sql的日志 class QueryLog(models.Model): - instance = models.ForeignKey(Instance, on_delete=models.CASCADE) instance_name = models.CharField('实例名称', max_length=50) db_name = models.CharField('数据库名称', max_length=30) sqllog = models.TextField('执行的sql查询') @@ -320,6 +323,7 @@ class DataMaskingColumns(models.Model): rule_type = models.IntegerField('规则类型', choices=((1, '手机号'), (2, '证件号码'), (3, '银行卡'), (4, '邮箱'), (5, '金额'), (6, '其他'))) active = models.IntegerField('激活状态', choices=((0, '未激活'), (1, '激活'))) + # TODO 暂时设置为允许为空, 迁移完成后再修改 instance = models.ForeignKey(Instance, on_delete=models.CASCADE) instance_name = models.CharField('实例名称', max_length=50) table_schema = models.CharField('字段所在库名', max_length=64) diff --git a/sql/query.py b/sql/query.py index 7758a1bd22..3917d1e2fb 100644 --- a/sql/query.py +++ b/sql/query.py @@ -45,6 +45,7 @@ def query_audit_call_back(workflow_id, workflow_status): insertlist = [QueryPrivileges( user_name=apply_queryset.user_name, user_display=apply_queryset.user_display, + instance = apply_queryset.instance, instance_name=apply_queryset.instance_name, db_name=db_name, table_name=apply_queryset.table_list, valid_date=apply_queryset.valid_date, limit_num=apply_queryset.limit_num, priv_type=apply_queryset.priv_type) for db_name in @@ -54,6 +55,7 @@ def query_audit_call_back(workflow_id, workflow_status): insertlist = [QueryPrivileges( user_name=apply_queryset.user_name, user_display=apply_queryset.user_display, + instance=apply_queryset.instance, instance_name=apply_queryset.instance_name, db_name=apply_queryset.db_list, table_name=table_name, valid_date=apply_queryset.valid_date, limit_num=apply_queryset.limit_num, priv_type=apply_queryset.priv_type) for table_name in @@ -77,12 +79,12 @@ def query_priv_check(user, instance_name, db_name, sql_content, limit_num): elif re.match(r"^show\s+create\s+table", sql_content.lower()): tb_name = re.sub('^show\s+create\s+table', '', sql_content, count=1, flags=0).strip() # 先判断是否有整库权限 - db_privileges = QueryPrivileges.objects.filter(user_name=user.username, instance_name=instance_name, + db_privileges = QueryPrivileges.objects.filter(user_name=user.username, instance= instance, db_name=db_name, priv_type=1, valid_date__gte=datetime.datetime.now(), is_deleted=0) # 无整库权限再验证表权限 if len(db_privileges) == 0: - tb_privileges = QueryPrivileges.objects.filter(user_name=user.username, instance_name=instance_name, + tb_privileges = QueryPrivileges.objects.filter(user_name=user.username, instance=instance, db_name=db_name, table_name=tb_name, priv_type=2, valid_date__gte=datetime.datetime.now(), is_deleted=0) if len(tb_privileges) == 0: @@ -98,7 +100,7 @@ def query_priv_check(user, instance_name, db_name, sql_content, limit_num): if table_ref_result['status'] == 0: table_ref = table_ref_result['data'] # 获取表信息,校验是否拥有全部表查询权限 - QueryPrivilegesOb = QueryPrivileges.objects.filter(user_name=user.username, instance_name=instance_name) + QueryPrivilegesOb = QueryPrivileges.objects.filter(user_name=user.username, instance=instance) # 先判断是否有整库权限 for table in table_ref: db_privileges = QueryPrivilegesOb.filter(db_name=table['db'], priv_type=1, @@ -116,7 +118,7 @@ def query_priv_check(user, instance_name, db_name, sql_content, limit_num): # 获取表数据报错,检查配置文件是否允许继续执行,并进行库权限校验 else: # 校验库权限,防止inception的语法树打印错误时连库权限也未做校验 - privileges = QueryPrivileges.objects.filter(user_name=user.username, instance_name=instance_name, + privileges = QueryPrivileges.objects.filter(user_name=user.username, instance=instance, db_name=db_name, valid_date__gte=datetime.datetime.now(), is_deleted=0) @@ -134,7 +136,7 @@ def query_priv_check(user, instance_name, db_name, sql_content, limit_num): db_list = [table_info['db'] for table_info in table_ref] table_list = [table_info['table'] for table_info in table_ref] user_limit_num = QueryPrivileges.objects.filter(user_name=user.username, - instance_name=instance_name, + instance=instance, db_name__in=db_list, table_name__in=table_list, valid_date__gte=datetime.datetime.now(), @@ -142,14 +144,14 @@ def query_priv_check(user, instance_name, db_name, sql_content, limit_num): if user_limit_num is None: # 如果表没获取到则获取涉及库的最小limit限制 user_limit_num = QueryPrivileges.objects.filter(user_name=user.username, - instance_name=instance_name, + instance=instance, db_name=db_name, valid_date__gte=datetime.datetime.now(), is_deleted=0 ).aggregate(Min('limit_num'))['limit_num__min'] else: # 如果表没获取到则获取涉及库的最小limit限制 user_limit_num = QueryPrivileges.objects.filter(user_name=user.username, - instance_name=instance_name, + instance=instance, db_name=db_name, valid_date__gte=datetime.datetime.now(), is_deleted=0).aggregate(Min('limit_num'))['limit_num__min'] @@ -239,10 +241,11 @@ def applyforprivileges(request): # 判断是否需要限制到表级别的权限 # 库权限 + ins = Instance.objects.get(instance_name=instance_name) if int(priv_type) == 1: db_list = db_list.split(',') # 检查申请账号是否已拥整个库的查询权限 - own_dbs = QueryPrivileges.objects.filter(instance_name=instance_name, user_name=user.username, + own_dbs = QueryPrivileges.objects.filter(instance=ins, user_name=user.username, db_name__in=db_list, valid_date__gte=datetime.datetime.now(), priv_type=1, is_deleted=0).values('db_name') @@ -259,7 +262,7 @@ def applyforprivileges(request): elif int(priv_type) == 2: table_list = table_list.split(',') # 检查申请账号是否已拥有该表的查询权限 - own_tables = QueryPrivileges.objects.filter(instance_name=instance_name, user_name=user.username, + own_tables = QueryPrivileges.objects.filter(instance=ins, user_name=user.username, db_name=db_name, table_name__in=table_list, valid_date__gte=datetime.datetime.now(), priv_type=2, is_deleted=0).values('table_name') @@ -277,25 +280,26 @@ def applyforprivileges(request): try: with transaction.atomic(): # 保存申请信息到数据库 - applyinfo = QueryPrivilegesApply() - applyinfo.title = title - applyinfo.group_id = group_id - applyinfo.group_name = group_name - applyinfo.audit_auth_groups = Audit.settings(group_id, WorkflowDict.workflow_type['query']) - applyinfo.user_name = user.username - applyinfo.user_display = user.display - applyinfo.instance_name = instance_name + applyinfo = QueryPrivilegesApply( + title=title, + group_id=group_id, + group_name=group_name, + audit_auth_groups=Audit.settings(group_id, WorkflowDict.workflow_type['query']), + user_name=user.username, + user_display=user.display, + instance=ins, + instance_name=instance_name, + priv_type=int(priv_type), + valid_date=valid_date, + status=WorkflowDict.workflow_status['audit_wait'], + limit_num=limit_num + ) if int(priv_type) == 1: applyinfo.db_list = ','.join(db_list) applyinfo.table_list = '' elif int(priv_type) == 2: applyinfo.db_list = db_name applyinfo.table_list = ','.join(table_list) - applyinfo.priv_type = int(priv_type) - applyinfo.valid_date = valid_date - applyinfo.status = WorkflowDict.workflow_status['audit_wait'] # 待审核 - applyinfo.limit_num = limit_num - applyinfo.create_user = user.username applyinfo.save() apply_id = applyinfo.apply_id @@ -331,14 +335,14 @@ def getuserprivileges(request): # 获取用户的权限数据 if user.is_superuser: if user_name != 'all': - privileges_list_obj = QueryPrivileges.objects.all().filter(user_name=user_name, - is_deleted=0, - table_name__contains=search, - valid_date__gte=datetime.datetime.now()) + privileges_list_obj = QueryPrivileges.objects.filter(user_name=user_name, + is_deleted=0, + table_name__contains=search, + valid_date__gte=datetime.datetime.now()) else: - privileges_list_obj = QueryPrivileges.objects.all().filter(is_deleted=0, - table_name__contains=search, - valid_date__gte=datetime.datetime.now()) + privileges_list_obj = QueryPrivileges.objects.filter(is_deleted=0, + table_name__contains=search, + valid_date__gte=datetime.datetime.now()) else: privileges_list_obj = QueryPrivileges.objects.filter(user_name=user.username, table_name__contains=search, diff --git a/sql/tests.py b/sql/tests.py index d6e55212cc..bb7bc04eb0 100644 --- a/sql/tests.py +++ b/sql/tests.py @@ -1,7 +1,7 @@ import json from datetime import timedelta, datetime from unittest.mock import MagicMock, patch, ANY - +from unittest import skip from django.conf import settings from django.contrib.auth import get_user_model from django.contrib.auth.models import Group @@ -131,6 +131,7 @@ def setUp(self): group_name='some_group', title='some_title', user_name='some_user', + instance=self.slave1, instance_name='some_ins', db_list='some_db,some_db2', limit_num=100, @@ -145,6 +146,7 @@ def setUp(self): group_name='some_group', title='some_title', user_name='some_user', + instance=self.slave1, instance_name='some_ins', db_list='some_db', table_list='some_table,some_tb2', @@ -158,6 +160,7 @@ def setUp(self): self.db_priv_for_user3 = QueryPrivileges( user_name=self.u3.username, user_display=self.u3.display, + instance=self.slave1, instance_name=self.slave1.instance_name, db_name='some_db', table_name='', @@ -168,6 +171,7 @@ def setUp(self): self.table_priv_for_user3 = QueryPrivileges( user_name=self.u3.username, user_display=self.u3.display, + instance=self.slave1, instance_name=self.slave1.instance_name, db_name='another_db', table_name='some_table', @@ -178,6 +182,7 @@ def setUp(self): self.db_priv_for_user3_another_instance = QueryPrivileges( user_name=self.u3.username, user_display=self.u3.display, + instance=self.slave2, instance_name=self.slave2.instance_name, db_name='some_db_another_instance', table_name='', @@ -187,14 +192,15 @@ def setUp(self): self.db_priv_for_user3_another_instance.save() def tearDown(self): + self.query_apply_1.delete() + self.query_apply_2.delete() + QueryPrivileges.objects.all().delete() self.u1.delete() self.u2.delete() self.u3.delete() + self.superuser1.delete() self.slave1.delete() self.slave2.delete() - self.query_apply_1.delete() - self.query_apply_2.delete() - QueryPrivileges.objects.all().delete() archer_config = SysConfig() archer_config.set('disable_star', False) @@ -355,6 +361,9 @@ def setUp(self): self.u1.save() self.superuser1 = User(username='super1', is_superuser=True) self.superuser1.save() + self.master1 = Instance(instance_name='test_master_instance', type='master', db_type='mysql', + host='testhost', port=3306, user='mysql_user', password='mysql_password') + self.master1.save() self.wf1 = SqlWorkflow( workflow_name='some_name', group_id=1, @@ -365,6 +374,7 @@ def setUp(self): create_time=self.now - timedelta(days=1), status='workflow_finish', is_backup='是', + instance=self.master1, instance_name='some_instance', db_name='some_db', sql_content='some_sql', @@ -385,6 +395,7 @@ def setUp(self): create_time=self.now - timedelta(days=1), status='workflow_manreviewing', is_backup='是', + instance=self.master1, instance_name='some_instance', db_name='some_db', sql_content='some_sql', @@ -397,10 +408,11 @@ def setUp(self): self.wf2.save() def tearDown(self): - self.u1.delete() - self.superuser1.delete() self.wf1.delete() self.wf2.delete() + self.master1.delete() + self.u1.delete() + self.superuser1.delete() def testWorkflowStatus(self): c = Client(header={}) @@ -622,6 +634,9 @@ def setUp(self): self.now = datetime.now() self.u1 = User(username='some_user', display='用户1') self.u1.save() + self.master1 = Instance(instance_name='test_master_instance', type='master', db_type='mysql', + host='testhost', port=3306, user='mysql_user', password='mysql_password') + self.master1.save() self.wf1 = SqlWorkflow( workflow_name='some_name2', group_id=1, @@ -632,6 +647,7 @@ def setUp(self): create_time=self.now - timedelta(days=1), status='workflow_executing', is_backup='是', + instance=self.master1, instance_name='some_instance', db_name='some_db', sql_content='some_sql', @@ -654,6 +670,7 @@ def tearDown(self): self.wf1.delete() self.u1.delete() self.task_result = None + self.master1.delete() @patch('sql.utils.execute_sql.notify_for_execute') @patch('sql.utils.execute_sql.Audit') @@ -736,3 +753,122 @@ def test_analyze_text_not_None(self): r = self.client.post(path='/sql_analyze/analyze/', data={"text": text, "instance_name": instance_name, "db_name": db_name}) self.assertListEqual(list(json.loads(r.content)['rows'][0].keys()), ['sql_id', 'sql', 'report']) + + +class UserQueryPrivilege(TestCase): + + def setUp(self): + self.u1 = User(username='some_user', display='用户1') + self.u1.save() + self.ins1 = Instance(instance_name='test_slave_instance', type='slave', db_type='mysql', + host='testhost', port=3306, user='mysql_user', password='mysql_password') + self.ins1.save() + self.tomorrow = datetime.now() + timedelta(days=1) + + def tearDown(self): + self.u1.delete() + self.ins1.delete() + QueryPrivileges.objects.all().delete() + + @skip('not implemented') + def test_add_privilege(self): + + self.assertEqual(0, QueryPrivileges.objects.filter( + instance=self.ins1, + db_name='some_db', + limit_num=100, + valid_date=self.tomorrow, + user_name=self.u1.username, + priv_type=1 + ).count()) + self.u1.add_db_privilege(instance=self.ins1, + db_name='some_db', + limit_num=100, + valid_date=self.tomorrow) + self.assertEqual(1, QueryPrivileges.objects.filter( + instance=self.ins1, + db_name='some_db', + limit_num=100, + valid_date=self.tomorrow, + user_name=self.u1.username, + is_deleted=0, + priv_type=2 + ).count()) + + @skip('not implemented') + def test_get_query_limit(self): + self.assertEqual(0, self.u1.get_db_limit(instance=self.ins1, db_name='some_db')) + new_priv = QueryPrivileges( + instance=self.ins1, + db_name='some_db', + limit_num=100, + valid_date=self.tomorrow, + user_name=self.u1.username, + priv_type=1) + new_priv.save() + self.assertEqual(new_priv.limit_num, self.u1.get_db_limit(instance=self.ins1, db_name='some_db')) + + self.assertEqual(0, self.u1.get_table_limit(instance=self.ins1, db_name='some_db', table_name='some_table')) + new_table_priv = QueryPrivileges( + instance=self.ins1, + db_name='some_db', + table_name='some_tb', + limit_num=100, + valid_date=self.tomorrow, + user_name=self.u1.username, + priv_type=2 + ) + new_table_priv.save() + self.assertEqual(new_table_priv.limit_num, + self.u1.get_table_limit(instance=self.ins1, db_name='some_db', table_name='some_table')) + + @skip('not implemented') + def test_revoke_privilege(self): + new_db_priv = QueryPrivileges( + instance=self.ins1, + db_name='some_db', + limit_num=100, + valid_date=self.tomorrow, + user_name=self.u1.username, + priv_type=1) + new_db_priv.save() + self.u1.revoke_db_query_privilege(instance=self.ins1, db_name='some_db') + self.assertEqual(0, QueryPrivileges.objects.filter( + instance=self.ins1, + db_name='some_db', + limit_num=100, + valid_date=self.tomorrow, + user_name=self.u1.username, + priv_type=1 + ).count()) + + new_table_priv = QueryPrivileges( + instance=self.ins1, + db_name='some_db', + table_name='some_table', + limit_num=100, + valid_date=self.tomorrow, + user_name=self.u1.username, + priv_type=2 + ) + new_table_priv.save() + self.u1.revoke_table_query_privilege(instance=self.ins1,db_name='some_db',table_name='some_table') + self.assertEqual(0, QueryPrivileges.objects.filter( + instance=self.ins1, + db_name='some_db', + table_name='some_table', + limit_num=100, + valid_date=self.tomorrow, + user_name=self.u1.username, + priv_type=2 + ).count()) + +class PrivilegeApplyTest(TestCase): + + def setUp(self): + self.u1 = User(username='some_user', display='用户1') + self.u1.save() + self.ins1 = Instance(instance_name='test_slave_instance', type='slave', db_type='mysql', + host='testhost', port=3306, user='mysql_user', password='mysql_password') + self.ins1.save() + self.tomorrow = datetime.now() + timedelta(days=1) \ No newline at end of file