diff --git a/tests/base.py b/tests/base.py index f090f66b162..1dc26c674bf 100644 --- a/tests/base.py +++ b/tests/base.py @@ -5,26 +5,55 @@ """Shared resources for tests.""" -import os import unittest +from functools import lru_cache from typing import Union from detection_rules.rule import TOMLRule from detection_rules.rule_loader import DeprecatedCollection, DeprecatedRule, RuleCollection, production_filter +@lru_cache +def default_rules() -> RuleCollection: + return RuleCollection.default() + + class BaseRuleTest(unittest.TestCase): """Base class for shared test cases which need to load rules""" + RULE_LOADER_FAIL = False + RULE_LOADER_FAIL_MSG = None + RULE_LOADER_FAIL_RAISED = False + @classmethod def setUpClass(cls): - os.environ["DR_NOTIFY_INTEGRATION_UPDATE_AVAILABLE"] = "1" - rc = RuleCollection.default() - cls.all_rules = rc.rules - cls.rule_lookup = rc.id_map - cls.production_rules = rc.filter(production_filter) - cls.deprecated_rules: DeprecatedCollection = rc.deprecated + # too noisy; refactor + # os.environ["DR_NOTIFY_INTEGRATION_UPDATE_AVAILABLE"] = "1" + + if not cls.RULE_LOADER_FAIL: + try: + rc = default_rules() + cls.all_rules = rc.rules + cls.rule_lookup = rc.id_map + cls.production_rules = rc.filter(production_filter) + cls.deprecated_rules: DeprecatedCollection = rc.deprecated + except Exception as e: + cls.RULE_LOADER_FAIL = True + cls.RULE_LOADER_FAIL_MSG = str(e) @staticmethod - def rule_str(rule: Union[DeprecatedRule, TOMLRule], trailer=' ->'): + def rule_str(rule: Union[DeprecatedRule, TOMLRule], trailer=' ->') -> str: return f'{rule.id} - {rule.name}{trailer or ""}' + + def setUp(self) -> None: + if self.RULE_LOADER_FAIL: + # limit the loader failure to just one run + # raise a dedicated test failure for the loader + if not self.RULE_LOADER_FAIL_RAISED: + self.RULE_LOADER_FAIL_RAISED = True + with self.subTest('Test that the rule loader loaded with no validation or other failures.'): + self.fail(f'Rule loader failure: \n{self.RULE_LOADER_FAIL_MSG}') + + self.skipTest('Rule loader failure') + else: + super().setUp()