diff --git a/.github/workflows/esql-validation.yml b/.github/workflows/esql-validation.yml new file mode 100644 index 00000000000..7887f630021 --- /dev/null +++ b/.github/workflows/esql-validation.yml @@ -0,0 +1,114 @@ +name: ES|QL Validation +on: + push: + branches: [ "main", "8.*", "9.*" ] + pull_request: + branches: [ "*" ] + paths: + - 'rules/**/*.toml' +jobs: + build-and-validate: + runs-on: ubuntu-latest + + steps: + - name: Setup Detection Rules + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5 + with: + fetch-depth: 0 + path: detection-rules + + - name: Check if new or modified rule files are ESQL rules + id: check-esql + run: | + cd detection-rules + + # Check if the event is a push + if [ "${{ github.event_name }}" = "push" ]; then + echo "Triggered by a push event. Setting run_esql=true." + echo "run_esql=true" >> $GITHUB_ENV + exit 0 + fi + + MODIFIED_FILES=$(git diff --name-only --diff-filter=AM HEAD~1 | grep '^rules/.*\.toml$' || true) + if [ -z "$MODIFIED_FILES" ]; then + echo "No modified or new .toml files found. Skipping workflow." + echo "run_esql=false" >> $GITHUB_ENV + exit 0 + fi + + if ! grep -q 'type = "esql"' $MODIFIED_FILES; then + echo "No 'type = \"esql\"' found in the modified .toml files. Skipping workflow." + echo "run_esql=false" >> $GITHUB_ENV + exit 0 + fi + + echo "run_esql=true" >> $GITHUB_ENV + + - name: Check out repository + env: + DR_CLOUD_ID: ${{ secrets.cloud_id }} + DR_API_KEY: ${{ secrets.api_key }} + if: ${{ !env.DR_CLOUD_ID && !env.DR_API_KEY && env.run_esql == 'true' }} + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5 + with: + path: elastic-container + repository: peasead/elastic-container + + - name: Build and run containers + env: + DR_CLOUD_ID: ${{ secrets.cloud_id }} + DR_API_KEY: ${{ secrets.api_key }} + if: ${{ !env.DR_CLOUD_ID && !env.DR_API_KEY && env.run_esql == 'true' }} + run: | + cd elastic-container + GENERATED_PASSWORD=$(openssl rand -base64 16) + sed -i "s|changeme|$GENERATED_PASSWORD|" .env + echo "::add-mask::$GENERATED_PASSWORD" + echo "GENERATED_PASSWORD=$GENERATED_PASSWORD" >> $GITHUB_ENV + set -x + bash elastic-container.sh start + + - name: Get API Key and setup auth + env: + DR_CLOUD_ID: ${{ secrets.cloud_id }} + DR_API_KEY: ${{ secrets.api_key }} + DR_ELASTICSEARCH_URL: "https://localhost:9200" + ES_USER: "elastic" + ES_PASSWORD: ${{ env.GENERATED_PASSWORD }} + if: ${{ !env.DR_CLOUD_ID && !env.DR_API_KEY && env.run_esql == 'true' }} + run: | + cd detection-rules + response=$(curl -k -X POST -u "$ES_USER:$ES_PASSWORD" -H "Content-Type: application/json" -d '{ + "name": "tmp-api-key", + "expiration": "1d" + }' "$DR_ELASTICSEARCH_URL/_security/api_key") + + DR_API_KEY=$(echo "$response" | jq -r '.encoded') + echo "::add-mask::$DR_API_KEY" + echo "DR_API_KEY=$DR_API_KEY" >> $GITHUB_ENV + + - name: Set up Python 3.13 + if: ${{ env.run_esql == 'true' }} + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6 + with: + python-version: '3.13' + + - name: Install dependencies + if: ${{ env.run_esql == 'true' }} + run: | + cd detection-rules + python -m pip install --upgrade pip + pip cache purge + pip install .[dev] + + - name: Remote Test ESQL Rules + if: ${{ env.run_esql == 'true' }} + env: + DR_CLOUD_ID: ${{ secrets.cloud_id || '' }} + DR_KIBANA_URL: ${{ secrets.cloud_id == '' && 'https://localhost:5601' || '' }} + DR_ELASTICSEARCH_URL: ${{ secrets.cloud_id == '' && 'https://localhost:9200' || '' }} + DR_API_KEY: ${{ secrets.api_key || env.DR_API_KEY }} + 
DR_IGNORE_SSL_ERRORS: ${{ secrets.cloud_id == '' && 'true' || '' }} + run: | + cd detection-rules + python -m detection_rules dev test esql-remote-validation diff --git a/CLI.md b/CLI.md index d59e025cca5..6fe7485a9f6 100644 --- a/CLI.md +++ b/CLI.md @@ -49,6 +49,10 @@ Using the environment variable `DR_BYPASS_TIMELINE_TEMPLATE_VALIDATION` will byp Using the environment variable `DR_CLI_MAX_WIDTH` will set a custom max width for the click CLI. For instance, some users may want to increase the default value in cases where help messages are cut off. +Using the environment variable `DR_REMOTE_ESQL_VALIDATION` will enable remote ESQL validation for rules that use ESQL queries. This validation will be performed whenever the rule is loaded including for example the view-rule command. This requires the appropriate kibana_url or cloud_id, api_key, and es_url to be set in the config file or as environment variables. + +Using the environment variable `DR_SKIP_EMPTY_INDEX_CLEANUP` will disable the cleanup of remote testing indexes that are created as part of the remote ESQL validation. By default, these indexes are deleted after the validation is complete, or upon validation error. + ## Importing rules into the repo You can import rules into the repo using the `create-rule` or `import-rules-to-repo` commands. Both of these commands will diff --git a/detection_rules/cli_utils.py b/detection_rules/cli_utils.py index c66c243f4b9..27201bd63f9 100644 --- a/detection_rules/cli_utils.py +++ b/detection_rules/cli_utils.py @@ -7,7 +7,10 @@ import datetime import functools import os +import re +import time import typing +import uuid from collections.abc import Callable from pathlib import Path from typing import Any @@ -27,6 +30,104 @@ RULES_CONFIG = parse_rules_config() +def schema_prompt(name: str, value: Any | None = None, is_required: bool = False, **options: Any) -> Any: # noqa: PLR0911, PLR0912, PLR0915 + """Interactively prompt based on schema requirements.""" + field_type = options.get("type") + pattern: str | None = options.get("pattern") + enum = options.get("enum", []) + minimum = int(options["minimum"]) if "minimum" in options else None + maximum = int(options["maximum"]) if "maximum" in options else None + min_item = int(options.get("min_items", 0)) + max_items = int(options.get("max_items", 9999)) + + default = options.get("default") + if default is not None and str(default).lower() in ("true", "false"): + default = str(default).lower() + + if "date" in name: + default = time.strftime("%Y/%m/%d") + + if name == "rule_id": + default = str(uuid.uuid4()) + + if len(enum) == 1 and is_required and field_type not in ("array", ["array"]): + return enum[0] + + def _check_type(_val: Any) -> bool: # noqa: PLR0911 + if field_type in ("number", "integer") and not str(_val).isdigit(): + print(f"Number expected but got: {_val}") + return False + if pattern: + match = re.match(pattern, _val) + if not match or len(match.group(0)) != len(_val): + print(f"{_val} did not match pattern: {pattern}!") + return False + if enum and _val not in enum: + print("{} not in valid options: {}".format(_val, ", ".join(enum))) + return False + if minimum and (type(_val) is int and int(_val) < minimum): + print(f"{_val!s} is less than the minimum: {minimum!s}") + return False + if maximum and (type(_val) is int and int(_val) > maximum): + print(f"{_val!s} is greater than the maximum: {maximum!s}") + return False + if type(_val) is str and field_type == "boolean" and _val.lower() not in ("true", "false"): + print(f"Boolean expected 
but got: {_val!s}") + return False + return True + + def _convert_type(_val: Any) -> Any: + if field_type == "boolean" and type(_val) is not bool: + _val = _val.lower() == "true" + return int(_val) if field_type in ("number", "integer") else _val + + prompt = ( + "{name}{default}{required}{multi}".format( + name=name, + default=f' [{default}] ("n/a" to leave blank) ' if default else "", + required=" (required) " if is_required else "", + multi=(" (multi, comma separated) " if field_type in ("array", ["array"]) else ""), + ).strip() + + ": " + ) + + while True: + result = value or input(prompt) or default + if result == "n/a": + result = None + + if not result: + if is_required: + value = None + continue + return None + + if field_type in ("array", ["array"]): + result_list = result.split(",") + + if not (min_item < len(result_list) < max_items): + if is_required: + value = None + break + return [] + + for value in result_list: + if not _check_type(value): + if is_required: + value = None # noqa: PLW2901 + break + return [] + if is_required and value is None: + continue + return [_convert_type(r) for r in result_list] + if _check_type(result): + return _convert_type(result) + if is_required: + value = None + continue + return None + + def single_collection(f: Callable[..., Any]) -> Callable[..., Any]: """Add arguments to get a RuleCollection by file, directory or a list of IDs""" from .misc import raise_client_error @@ -145,7 +246,6 @@ def rule_prompt( # noqa: PLR0912, PLR0913, PLR0915 **kwargs: Any, ) -> TOMLRule | str: """Prompt loop to build a rule.""" - from .misc import schema_prompt additional_required = additional_required or [] creation_date = datetime.date.today().strftime("%Y/%m/%d") # noqa: DTZ011 diff --git a/detection_rules/devtools.py b/detection_rules/devtools.py index 19d059eafe9..f56b1d4840d 100644 --- a/detection_rules/devtools.py +++ b/detection_rules/devtools.py @@ -25,7 +25,8 @@ import pytoml # type: ignore[reportMissingTypeStubs] import requests.exceptions import yaml -from elasticsearch import Elasticsearch +from elasticsearch import BadRequestError, Elasticsearch +from elasticsearch import ConnectionError as ESConnectionError from eql.table import Table # type: ignore[reportMissingTypeStubs] from eql.utils import load_dump # type: ignore[reportMissingTypeStubs, reportUnknownVariableType] from kibana.connector import Kibana # type: ignore[reportMissingTypeStubs] @@ -39,6 +40,7 @@ from .docs import REPO_DOCS_DIR, IntegrationSecurityDocs, IntegrationSecurityDocsMDX from .ecs import download_endpoint_schemas, download_schemas from .endgame import EndgameSchemaManager +from .esql_errors import EsqlKibanaBaseError, EsqlSchemaError, EsqlSyntaxError, EsqlTypeMismatchError from .eswrap import CollectEvents, add_range_to_dsl from .ghwrap import GithubClient, update_gist from .integrations import ( @@ -50,7 +52,13 @@ load_integrations_manifests, ) from .main import root -from .misc import PYTHON_LICENSE, add_client, raise_client_error +from .misc import ( + PYTHON_LICENSE, + add_client, + get_default_elasticsearch_client, + get_default_kibana_client, + raise_client_error, +) from .packaging import CURRENT_RELEASE_PATH, PACKAGE_FILE, RELEASE_DIR, Package from .rule import ( AnyRuleData, @@ -63,6 +71,7 @@ TOMLRuleContents, ) from .rule_loader import RuleCollection, production_filter +from .rule_validators import ESQLValidator from .schemas import definitions, get_stack_versions from .utils import check_version_lock_double_bumps, dict_hash, get_etc_path, get_path from 
.version_lock import VersionLockFile, loaded_version_lock
@@ -1403,6 +1412,74 @@ def rule_event_search( # noqa: PLR0913
     raise_client_error("Rule is not a query rule!")

+@test_group.command("esql-remote-validation")
+@click.option(
+    "--verbosity",
+    type=click.IntRange(0, 1),
+    default=0,
+    help="Set verbosity level: 0 for minimal output, 1 for detailed output.",
+)
+def esql_remote_validation(
+    verbosity: int,
+) -> None:
+    """Remotely validate all ES|QL rules against an Elastic Stack."""
+
+    rule_collection: RuleCollection = RuleCollection.default().filter(production_filter)
+    esql_rules = [r for r in rule_collection if r.contents.data.type == "esql"]
+
+    click.echo(f"ESQL rules loaded: {len(esql_rules)}")
+
+    if not esql_rules:
+        return
+    # TODO(eric-forte-elastic): @add_client https://github.com/elastic/detection-rules/issues/5156 # noqa: FIX002
+    with get_default_kibana_client() as kibana_client, get_default_elasticsearch_client() as elastic_client:
+        if not kibana_client or not elastic_client:
+            raise_client_error("Skipping remote validation due to missing client")
+
+        failed_count = 0
+        fail_list: list[str] = []
+        max_retries = 3
+        for r in esql_rules:
+            retry_count = 0
+            while retry_count < max_retries:
+                try:
+                    validator = ESQLValidator(r.contents.data.query)  # type: ignore[reportIncompatibleMethodOverride]
+                    _ = validator.remote_validate_rule_contents(kibana_client, elastic_client, r.contents, verbosity)
+                    break
+                except (
+                    ValueError,
+                    BadRequestError,
+                    EsqlSchemaError,
+                    EsqlSyntaxError,
+                    EsqlTypeMismatchError,
+                    EsqlKibanaBaseError,
+                ) as e:
+                    click.echo(f"FAILURE: {e}")
+                    fail_list.append(f"{r.contents.data.rule_id} FAILURE: {type(e)}: {e}")
+                    failed_count += 1
+                    break
+                except ESConnectionError as e:
+                    retry_count += 1
+                    click.echo(f"Connection error: {e}. Retrying {retry_count}/{max_retries}...")
+                    time.sleep(30)
+                    if retry_count == max_retries:
+                        click.echo(f"FAILURE: {e} after {max_retries} retries")
+                        fail_list.append(f"{r.contents.data.rule_id} FAILURE: {e} after {max_retries} retries")
+                        failed_count += 1
+
+        click.echo(f"Total rules: {len(esql_rules)}")
+        click.echo(f"Failed rules: {failed_count}")
+
+        _ = Path("failed_rules.log").write_text("\n".join(fail_list), encoding="utf-8")
+        click.echo("Failed rules written to failed_rules.log")
+        if failed_count > 0:
+            click.echo("Failed rule IDs:")
+            uuids = {line.split()[0] for line in fail_list}
+            click.echo("\n".join(uuids))
+            ctx = click.get_current_context()
+            ctx.exit(1)
+
+
 @test_group.command("rule-survey")
 @click.argument("query", required=False)
 @click.option("--date-range", "-d", type=(str, str), default=("now-7d", "now"), help="Date range to scope search")
diff --git a/detection_rules/esql.py b/detection_rules/esql.py
new file mode 100644
index 00000000000..a14cc91004d
--- /dev/null
+++ b/detection_rules/esql.py
@@ -0,0 +1,60 @@
+# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+# or more contributor license agreements. Licensed under the Elastic License
+# 2.0; you may not use this file except in compliance with the Elastic License
+# 2.0.
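The `esql-remote-validation` command above retries only transient Elasticsearch connection errors, pausing 30 seconds between attempts, while validation errors fail the rule immediately. A minimal standalone sketch of that retry pattern (hypothetical helper names, not part of this diff):

import time

from elasticsearch import ConnectionError as ESConnectionError


def run_with_retries(validate, max_retries=3, delay_seconds=30):
    """Call validate() and retry only on transient connection errors."""
    for attempt in range(1, max_retries + 1):
        try:
            return validate()
        except ESConnectionError:
            if attempt == max_retries:
                raise
            time.sleep(delay_seconds)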
+ +"""ESQL exceptions.""" + +import re +from dataclasses import dataclass + + +@dataclass +class EventDataset: + """Dataclass for event.dataset with integration and datastream parts.""" + + package: str + integration: str + + def __str__(self) -> str: + return f"{self.package}.{self.integration}" + + +def get_esql_query_event_dataset_integrations(query: str) -> list[EventDataset]: + """Extract event.dataset and data_stream.dataset integrations from an ES|QL query.""" + number_of_parts = 2 + # Regex patterns for event.dataset, and data_stream.dataset + # This mimics the logic in get_datasets_and_modules but for ES|QL as we do not have an ast + + regex_patterns = { + "in": [ + re.compile(r"event\.dataset\s+in\s*\(\s*([^)]+)\s*\)"), + re.compile(r"data_stream\.dataset\s+in\s*\(\s*([^)]+)\s*\)"), + ], + "eq": [ + re.compile(r'event\.dataset\s*==\s*"([^"]+)"'), + re.compile(r'data_stream\.dataset\s*==\s*"([^"]+)"'), + ], + } + + # Extract datasets + datasets: list[str] = [] + for regex_list in regex_patterns.values(): + for regex in regex_list: + matches = regex.findall(query) + if matches: + for match in matches: + if "," in match: + # Handle `in` case with multiple values + datasets.extend([ds.strip().strip('"') for ds in match.split(",")]) + else: + # Handle `==` case + datasets.append(match.strip().strip('"')) + + event_datasets: list[EventDataset] = [] + for dataset in datasets: + parts = dataset.split(".") + if len(parts) == number_of_parts: # Ensure there are exactly two parts + event_datasets.append(EventDataset(package=parts[0], integration=parts[1])) + + return event_datasets diff --git a/detection_rules/esql_errors.py b/detection_rules/esql_errors.py new file mode 100644 index 00000000000..5a43729cabb --- /dev/null +++ b/detection_rules/esql_errors.py @@ -0,0 +1,76 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0; you may not use this file except in compliance with the Elastic License +# 2.0. + +"""ESQL exceptions.""" + +from elasticsearch import Elasticsearch # type: ignore[reportMissingTypeStubs] + +from .misc import getdefault + +__all__ = ( + "EsqlKibanaBaseError", + "EsqlSchemaError", + "EsqlSemanticError", + "EsqlSyntaxError", + "EsqlTypeMismatchError", + "EsqlUnknownIndexError", + "EsqlUnsupportedTypeError", +) + + +def cleanup_empty_indices( + elastic_client: Elasticsearch, index_patterns: tuple[str, ...] = ("rule-test-*", "test-*") +) -> None: + """Delete empty indices matching the given patterns.""" + if getdefault("skip_empty_index_cleanup")(): + return + for pattern in index_patterns: + indices = elastic_client.cat.indices(index=pattern, format="json") + empty_indices = [index["index"] for index in indices if index["docs.count"] == "0"] # type: ignore[reportMissingTypeStubs] + for empty_index in empty_indices: + _ = elastic_client.indices.delete(index=empty_index) + + +class EsqlKibanaBaseError(Exception): + """Base class for ESQL exceptions with cleanup logic.""" + + def __init__(self, message: str, elastic_client: Elasticsearch) -> None: + cleanup_empty_indices(elastic_client) + super().__init__(message) + + +class EsqlSchemaError(EsqlKibanaBaseError): + """Error in ESQL schema. 
Validated via Kibana until AST is available.""" + + +class EsqlUnsupportedTypeError(EsqlKibanaBaseError): + """Error in ESQL type validation using unsupported type.""" + + +class EsqlSyntaxError(EsqlKibanaBaseError): + """Error with ESQL syntax.""" + + +class EsqlTypeMismatchError(Exception): + """Error when validating types in ESQL. Can occur in stack or local schema comparison.""" + + def __init__(self, message: str, elastic_client: Elasticsearch | None = None) -> None: + if elastic_client: + cleanup_empty_indices(elastic_client) + super().__init__(message) + + +class EsqlSemanticError(Exception): + """Error with ESQL semantics. Validated through regex enforcement.""" + + def __init__(self, message: str) -> None: + super().__init__(message) + + +class EsqlUnknownIndexError(Exception): + """Error with ESQL Indices. Validated through regex enforcement.""" + + def __init__(self, message: str) -> None: + super().__init__(message) diff --git a/detection_rules/index_mappings.py b/detection_rules/index_mappings.py new file mode 100644 index 00000000000..729d224fc8a --- /dev/null +++ b/detection_rules/index_mappings.py @@ -0,0 +1,457 @@ +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0; you may not use this file except in compliance with the Elastic License +# 2.0. + +"""Validation logic for rules containing queries.""" + +import re +import time +from collections.abc import Callable +from copy import deepcopy +from typing import Any + +from elastic_transport import ObjectApiResponse +from elasticsearch import Elasticsearch # type: ignore[reportMissingTypeStubs] +from elasticsearch.exceptions import BadRequestError +from semver import Version + +from . import ecs, integrations, misc, utils +from .config import load_current_package_version +from .esql import EventDataset +from .esql_errors import ( + EsqlKibanaBaseError, + EsqlSchemaError, + EsqlSyntaxError, + EsqlTypeMismatchError, + EsqlUnknownIndexError, + EsqlUnsupportedTypeError, + cleanup_empty_indices, +) +from .integrations import ( + load_integrations_manifests, + load_integrations_schemas, +) +from .rule import RuleMeta +from .schemas import get_stack_schemas +from .schemas.definitions import HTTP_STATUS_BAD_REQUEST +from .utils import combine_dicts + + +def delete_nested_key_from_dict(d: dict[str, Any], compound_key: str) -> None: + """Delete a nested key from a dictionary.""" + keys = compound_key.split(".") + for key in keys[:-1]: + if key in d and isinstance(d[key], dict): + d = d[key] # type: ignore[reportUnknownVariableType] + else: + return + d.pop(keys[-1], None) + + +def flat_schema_to_index_mapping(flat_schema: dict[str, str]) -> dict[str, Any]: + """ + Convert dicts with flat JSON paths and values into a nested mapping with + intermediary `properties`, `fields` and `type` fields. + """ + + # Sorting here ensures that 'a.b' processed before 'a.b.c', allowing us to correctly + # detect and handle multi-fields. 
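    # Example (hypothetical field names): a flat schema such as
    #     {"process.name": "keyword", "process.name.text": "text"}
    # is expected to become a nested index mapping along the lines of
    #     {"process": {"properties": {"name": {"type": "keyword",
    #                                          "fields": {"text": {"type": "text"}}}}}}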
+ sorted_items = sorted(flat_schema.items()) + result = {} + + for field_path, field_type in sorted_items: + parts = field_path.split(".") + current_level = result + + for part in parts[:-1]: + node = current_level.setdefault(part, {}) # type: ignore[reportUnknownVariableType] + + if "type" in node and node["type"] not in ("nested", "object"): + current_level = node.setdefault("fields", {}) # type: ignore[reportUnknownVariableType] + else: + current_level = node.setdefault("properties", {}) # type: ignore[reportUnknownVariableType] + + leaf_key = parts[-1] + current_level[leaf_key] = {"type": field_type} + + # add `scaling_factor` field missing in the schema + # https://www.elastic.co/docs/reference/elasticsearch/mapping-reference/number#scaled-float-params + if field_type == "scaled_float": + current_level[leaf_key]["scaling_factor"] = 1000 + + # add `path` field for `alias` fields, set to a dummy value + if field_type == "alias": + current_level[leaf_key]["path"] = "@timestamp" + + return result # type: ignore[reportUnknownVariableType] + + +def get_rule_integrations(metadata: RuleMeta) -> list[str]: + """Retrieve rule integrations from metadata.""" + if metadata.integration: + rule_integrations: list[str] = ( + metadata.integration if isinstance(metadata.integration, list) else [metadata.integration] + ) + else: + rule_integrations: list[str] = [] + return rule_integrations + + +def create_index_with_index_mapping( + elastic_client: Elasticsearch, index_name: str, mappings: dict[str, Any] +) -> ObjectApiResponse[Any] | None: + """Create an index with the specified mappings and settings to support large number of fields and nested objects.""" + try: + return elastic_client.indices.create( + index=index_name, + mappings={"properties": mappings}, + settings={ + "index.mapping.total_fields.limit": 10000, + "index.mapping.nested_fields.limit": 500, + "index.mapping.nested_objects.limit": 10000, + }, + ) + except BadRequestError as e: + error_message = str(e) + if ( + e.status_code == HTTP_STATUS_BAD_REQUEST + and "validation_exception" in error_message + and "Validation Failed: 1: this action would add [2] shards" in error_message + ): + cleanup_empty_indices(elastic_client) + try: + return elastic_client.indices.create( + index=index_name, + mappings={"properties": mappings}, + settings={ + "index.mapping.total_fields.limit": 10000, + "index.mapping.nested_fields.limit": 500, + "index.mapping.nested_objects.limit": 10000, + }, + ) + except BadRequestError as retry_error: + raise EsqlSchemaError(str(retry_error), elastic_client) from retry_error + raise EsqlSchemaError(error_message, elastic_client) from e + + +def get_existing_mappings(elastic_client: Elasticsearch, indices: list[str]) -> tuple[dict[str, Any], dict[str, Any]]: + """Retrieve mappings for all matching existing index templates.""" + existing_mappings: dict[str, Any] = {} + index_lookup: dict[str, Any] = {} + for index in indices: + index_tmpl_mappings = get_simulated_index_template_mappings(elastic_client, index) + index_lookup[index] = index_tmpl_mappings + combine_dicts(existing_mappings, index_tmpl_mappings) + return existing_mappings, index_lookup + + +def get_simulated_index_template_mappings(elastic_client: Elasticsearch, name: str) -> dict[str, Any]: + """ + Return the mappings from the index configuration that would be applied + to the specified index from an existing index template + + https://elasticsearch-py.readthedocs.io/en/stable/api/indices.html#elasticsearch.client.IndicesClient.simulate_index_template + """ + 
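    # The simulated template response is shaped roughly like (abridged, values hypothetical):
    #     {"template": {"settings": {...}, "aliases": {...},
    #                   "mappings": {"properties": {"@timestamp": {"type": "date"}, ...}}}}
    # Only the `properties` block of the simulated mappings is returned below.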
template = elastic_client.indices.simulate_index_template(name=name) + if not template: + return {} + return template["template"]["mappings"]["properties"] + + +def prepare_integration_mappings( # noqa: PLR0913 + rule_integrations: list[str], + event_dataset_integrations: list[EventDataset], + package_manifests: Any, + integration_schemas: Any, + stack_version: str, + log: Callable[[str], None], +) -> tuple[dict[str, Any], dict[str, Any]]: + """Prepare integration mappings for the given rule integrations.""" + integration_mappings: dict[str, Any] = {} + index_lookup: dict[str, Any] = {} + dataset_restriction: dict[str, str] = {} + + # Process restrictions, note we need this for loops to be separate + for event_dataset in event_dataset_integrations: + # Ensure the integration is in rule_integrations + if event_dataset.package not in rule_integrations: + dataset_restriction.setdefault(event_dataset.package, []).append(event_dataset.integration) # type: ignore[reportIncompatibleMethodOverride] + for event_dataset in event_dataset_integrations: + if event_dataset.package not in rule_integrations: + rule_integrations.append(event_dataset.package) + + for integration in rule_integrations: + package = integration + package_version, _ = integrations.find_latest_compatible_version( + package, + "", + Version.parse(stack_version), + package_manifests, + ) + package_schema = integration_schemas[package][package_version] + + # Apply dataset restrictions if any + if integration in dataset_restriction: + allowed_keys = dataset_restriction[integration] + package_schema = {key: value for key, value in package_schema.items() if key in allowed_keys} + + for stream in package_schema: + flat_schema = package_schema[stream] + stream_mappings = flat_schema_to_index_mapping(flat_schema) + nested_multifields = find_nested_multifields(stream_mappings) + for field in nested_multifields: + field_name = str(field).split(".fields.")[0].replace(".", ".properties.") + ".fields" + log( + f"Warning: Nested multi-field `{field}` found in `{integration}-{stream}`. " + f"Removing parent field from schema for ES|QL validation." + ) + delete_nested_key_from_dict(stream_mappings, field_name) + nested_flattened_fields = find_flattened_fields_with_subfields(stream_mappings) + for field in nested_flattened_fields: + field_name = str(field).split(".fields.")[0].replace(".", ".properties.") + ".fields" + log( + f"Warning: flattened field `{field}` found in `{integration}-{stream}` with sub fields. " + f"Removing parent field from schema for ES|QL validation." + ) + delete_nested_key_from_dict(stream_mappings, field_name) + utils.combine_dicts(integration_mappings, deepcopy(stream_mappings)) + index_lookup[f"{integration}-{stream}"] = stream_mappings + + return integration_mappings, index_lookup + + +def get_filtered_index_schema( + indices: list[str], + index_lookup: dict[str, Any], + ecs_schema: dict[str, Any], + non_ecs_mapping: dict[str, Any], + custom_mapping: dict[str, Any], +) -> dict[str, Any]: + """Check if the provided indices are known based on the integration format. Returns the combined schema.""" + + non_ecs_indices = ecs.get_non_ecs_schema() + custom_indices = ecs.get_custom_schemas() + + # Assumes valid index format is logs-.* or logs-.-* + filtered_keys = {"logs-" + key.replace("-", ".") + "*" for key in index_lookup if key not in indices} + filtered_keys.update({"logs-" + key.replace("-", ".") + "-*" for key in index_lookup if key not in indices}) + # Replace "logs-endpoint." with "logs-endpoint.events." 
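    # Example of the rewrites in this block (hypothetical entry): an index_lookup key
    # "endpoint-process" was expanded above to "logs-endpoint.process*" and
    # "logs-endpoint.process-*"; the rename below turns those into
    # "logs-endpoint.events.process*" and "logs-endpoint.events.process-*".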
+ filtered_keys = { + key.replace("logs-endpoint.", "logs-endpoint.events.") if "logs-endpoint." in key else key + for key in filtered_keys + } + filtered_keys.update(non_ecs_indices.keys()) + filtered_keys.update(custom_indices.keys()) + filtered_keys.add("logs-endpoint.alerts-*") + + matches: list[str] = [] + for index in indices: + pattern = re.compile(index.replace(".", r"\.").replace("*", ".*").rstrip("-")) + matches = [key for key in filtered_keys if pattern.fullmatch(key)] + + if not matches: + raise EsqlUnknownIndexError( + f"Unknown index pattern(s): {', '.join(indices)}. Known patterns: {', '.join(filtered_keys)}" + ) + + filtered_index_lookup = { + "logs-" + key.replace("-", ".") + "*": value for key, value in index_lookup.items() if key not in indices + } + filtered_index_lookup.update( + {"logs-" + key.replace("-", ".") + "-*": value for key, value in index_lookup.items() if key not in indices} + ) + filtered_index_lookup = { + key.replace("logs-endpoint.", "logs-endpoint.events."): value for key, value in filtered_index_lookup.items() + } + filtered_index_lookup.update(non_ecs_mapping) + filtered_index_lookup.update(custom_mapping) + + combined_mappings: dict[str, Any] = {} + for match in matches: + utils.combine_dicts(combined_mappings, deepcopy(filtered_index_lookup.get(match, {}))) + + utils.combine_dicts(combined_mappings, deepcopy(ecs_schema)) + return combined_mappings + + +def create_remote_indices( + elastic_client: Elasticsearch, + existing_mappings: dict[str, Any], + index_lookup: dict[str, Any], + log: Callable[[str], None], +) -> str: + """Create remote indices for validation and return the index string.""" + + suffix = str(int(time.time() * 1000)) + test_index = f"rule-test-index-{suffix}" + response = create_index_with_index_mapping(elastic_client, test_index, existing_mappings) + log(f"Index `{test_index}` created: {response}") + full_index_str = test_index + + # create all integration indices + for index, properties in index_lookup.items(): + ind_index_str = f"test-{index.rstrip('*')}{suffix}" + response = create_index_with_index_mapping(elastic_client, ind_index_str, properties) + log(f"Index `{ind_index_str}` created: {response}") + full_index_str = f"{full_index_str}, {ind_index_str}" + + return full_index_str + + +def execute_query_against_indices( + elastic_client: Elasticsearch, + query: str, + test_index_str: str, + log: Callable[[str], None], + delete_indices: bool = True, +) -> tuple[list[Any], ObjectApiResponse[Any]]: + """Execute the ESQL query against the test indices on a remote Stack and return the columns.""" + try: + log(f"Executing a query against `{test_index_str}`") + response = elastic_client.esql.query(query=query) + log(f"Got query response: {response}") + query_columns = response.get("columns", []) + except BadRequestError as e: + error_msg = str(e) + if "parsing_exception" in error_msg: + raise EsqlSyntaxError(str(e), elastic_client) from e + if "Unknown column" in error_msg: + raise EsqlSchemaError(str(e), elastic_client) from e + if "verification_exception" in error_msg and "unsupported type" in error_msg: + raise EsqlUnsupportedTypeError(str(e), elastic_client) from e + if "verification_exception" in error_msg: + raise EsqlTypeMismatchError(str(e), elastic_client) from e + raise EsqlKibanaBaseError(str(e), elastic_client) from e + if delete_indices or not misc.getdefault("skip_empty_index_cleanup")(): + for index_str in test_index_str.split(","): + response = elastic_client.indices.delete(index=index_str.strip()) + log(f"Test index 
`{index_str}` deleted: {response}") + + query_column_names = [c["name"] for c in query_columns] + log(f"Got query columns: {', '.join(query_column_names)}") + return query_columns, response + + +def find_nested_multifields(mapping: dict[str, Any], path: str = "") -> list[Any]: + """Recursively search for nested multi-fields in Elasticsearch mappings.""" + nested_multifields = [] + + for field, properties in mapping.items(): + current_path = f"{path}.{field}" if path else field + + if isinstance(properties, dict): + # Check if the field has a `fields` key + if "fields" in properties: + # Check if any subfield in `fields` also has a `fields` key + for subfield, subproperties in properties["fields"].items(): # type: ignore[reportUnknownVariableType] + if isinstance(subproperties, dict) and "fields" in subproperties: + nested_multifields.append(f"{current_path}.fields.{subfield}") # type: ignore[reportUnknownVariableType] + + # Recurse into subfields + if "properties" in properties: + nested_multifields.extend( # type: ignore[reportUnknownVariableType] + find_nested_multifields(properties["properties"], current_path) # type: ignore[reportUnknownVariableType] + ) + + return nested_multifields # type: ignore[reportUnknownVariableType] + + +def find_flattened_fields_with_subfields(mapping: dict[str, Any], path: str = "") -> list[str]: + """Recursively search for fields of type 'flattened' that have a 'fields' key in Elasticsearch mappings.""" + flattened_fields_with_subfields: list[str] = [] + + for field, properties in mapping.items(): + current_path = f"{path}.{field}" if path else field + + if isinstance(properties, dict): + # Check if the field is of type 'flattened' and has a 'fields' key + if properties.get("type") == "flattened" and "fields" in properties: # type: ignore[reportUnknownVariableType] + flattened_fields_with_subfields.append(current_path) # type: ignore[reportUnknownVariableType] + + # Recurse into subfields + if "properties" in properties: + flattened_fields_with_subfields.extend( # type: ignore[reportUnknownVariableType] + find_flattened_fields_with_subfields(properties["properties"], current_path) # type: ignore[reportUnknownVariableType] + ) + + return flattened_fields_with_subfields + + +def get_ecs_schema_mappings(current_version: Version) -> dict[str, Any]: + """Get the ECS schema in an index mapping format (nested schema) handling scaled floats.""" + ecs_version = get_stack_schemas()[str(current_version)]["ecs"] + ecs_schemas = ecs.get_schemas() + ecs_schema_flattened: dict[str, Any] = {} + ecs_schema_scaled_floats: dict[str, Any] = {} + for index, info in ecs_schemas[ecs_version]["ecs_flat"].items(): + if info["type"] == "scaled_float": + ecs_schema_scaled_floats.update({index: info["scaling_factor"]}) + ecs_schema_flattened.update({index: info["type"]}) + ecs_schema = utils.convert_to_nested_schema(ecs_schema_flattened) + for index, info in ecs_schema_scaled_floats.items(): + parts = index.split(".") + current = ecs_schema + + # Traverse the ecs_schema to the correct nested dictionary + for part in parts[:-1]: # Traverse all parts except the last one + current = current.setdefault(part, {}).setdefault("properties", {}) + + current[parts[-1]].update({"scaling_factor": info}) + return ecs_schema + + +def prepare_mappings( # noqa: PLR0913 + elastic_client: Elasticsearch, + indices: list[str], + event_dataset_integrations: list[EventDataset], + metadata: RuleMeta, + stack_version: str, + log: Callable[[str], None], +) -> tuple[dict[str, Any], dict[str, Any], dict[str, 
Any]]: + """Prepare index mappings for the given indices and rule integrations.""" + existing_mappings, index_lookup = get_existing_mappings(elastic_client, indices) + + # Collect mappings for the integrations + rule_integrations = get_rule_integrations(metadata) + + # Collect mappings for all relevant integrations for the given stack version + package_manifests = load_integrations_manifests() + integration_schemas = load_integrations_schemas() + + integration_mappings, integration_index_lookup = prepare_integration_mappings( + rule_integrations, event_dataset_integrations, package_manifests, integration_schemas, stack_version, log + ) + + index_lookup.update(integration_index_lookup) + + # Load non-ecs schema and convert to index mapping format (nested schema) + non_ecs_mapping: dict[str, Any] = {} + non_ecs = ecs.get_non_ecs_schema() + for index in indices: + non_ecs_mapping.update(non_ecs.get(index, {})) + non_ecs_mapping = ecs.flatten(non_ecs_mapping) + non_ecs_mapping = utils.convert_to_nested_schema(non_ecs_mapping) + + # Load custom schema and convert to index mapping format (nested schema) + custom_mapping: dict[str, Any] = {} + custom_indices = ecs.get_custom_schemas() + for index in indices: + custom_mapping.update(custom_indices.get(index, {})) + custom_mapping = ecs.flatten(custom_mapping) + custom_mapping = utils.convert_to_nested_schema(custom_mapping) + + # Load ECS in an index mapping format (nested schema) + current_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) + ecs_schema = get_ecs_schema_mappings(current_version) + + # Filter combined mappings based on the provided indices + combined_mappings = get_filtered_index_schema(indices, index_lookup, ecs_schema, non_ecs_mapping, custom_mapping) + + index_lookup.update({"rule-ecs-index": ecs_schema}) + + if (not integration_mappings or existing_mappings) and not non_ecs_mapping and not ecs_schema: + raise ValueError("No mappings found") + index_lookup.update({"rule-non-ecs-index": non_ecs_mapping}) + + return existing_mappings, index_lookup, combined_mappings diff --git a/detection_rules/integrations.py b/detection_rules/integrations.py index 5c96667296e..da8c34a24af 100644 --- a/detection_rules/integrations.py +++ b/detection_rules/integrations.py @@ -225,7 +225,7 @@ def find_latest_compatible_version( rule_stack_version: Version, packages_manifest: dict[str, Any], ) -> tuple[str, list[str]]: - """Finds least compatible version for specified integration based on stack version supplied.""" + """Finds latest compatible version for specified integration based on stack version supplied.""" if not package: raise ValueError("Package must be specified") @@ -430,6 +430,7 @@ def collect_schema_fields( def parse_datasets(datasets: list[str], package_manifest: dict[str, Any]) -> list[dict[str, Any]]: """Parses datasets into packaged integrations from rule data.""" packaged_integrations: list[dict[str, Any]] = [] + # FIXME @eric-forte-elastic: evaluate using EventDataset dataclass for parsing # noqa: FIX001, TD001, TD003 for _value in sorted(datasets): # cleanup extra quotes pulled from ast field value = _value.strip('"') diff --git a/detection_rules/main.py b/detection_rules/main.py index 7de190520d2..879b252e413 100644 --- a/detection_rules/main.py +++ b/detection_rules/main.py @@ -30,10 +30,17 @@ from .config import load_current_package_version, parse_rules_config from .exception import TOMLExceptionContents, build_exception_objects, parse_exceptions_results_from_api from .generic_loader import 
GenericCollection -from .misc import add_client, nested_set, parse_user_config, raise_client_error -from .rule import DeprecatedRule, QueryRuleData, TOMLRule, TOMLRuleContents +from .misc import ( + add_client, + getdefault, + nested_set, + parse_user_config, + raise_client_error, +) +from .rule import DeprecatedRule, ESQLRuleData, QueryRuleData, RuleMeta, TOMLRule, TOMLRuleContents from .rule_formatter import toml_write from .rule_loader import RawRuleCollection, RuleCollection, update_metadata_from_file +from .rule_validators import ESQLValidator from .schemas import all_versions, definitions, get_incompatible_fields, get_schema_file from .utils import ( Ndjson, @@ -446,10 +453,21 @@ def mass_update( @root.command("view-rule") @click.argument("rule-file", type=Path) @click.option("--api-format/--rule-format", default=True, help="Print the rule in final api or rule format") +@click.option("--esql-remote-validation", is_flag=True, default=False, help="Enable remote validation for the rule") @click.pass_context -def view_rule(_: click.Context, rule_file: Path, api_format: str) -> TOMLRule | DeprecatedRule: +def view_rule( + _: click.Context, rule_file: Path, api_format: str, esql_remote_validation: bool +) -> TOMLRule | DeprecatedRule: """View an internal rule or specified rule file.""" rule = RuleCollection().load_file(rule_file) + if ( + esql_remote_validation + and isinstance(rule.contents.data, ESQLRuleData) + and isinstance(rule.contents.data.validator, ESQLValidator) + and isinstance(rule.contents.metadata, RuleMeta) + and not getdefault("remote_esql_validation")() + ): + rule.contents.data.validator.validate(rule.contents.data, rule.contents.metadata, force_remote_validation=True) if api_format: click.echo(json.dumps(rule.contents.to_api_format(), indent=2, sort_keys=True)) diff --git a/detection_rules/misc.py b/detection_rules/misc.py index 7992639ea72..064e60c2912 100644 --- a/detection_rules/misc.py +++ b/detection_rules/misc.py @@ -6,10 +6,7 @@ """Misc support.""" import os -import re -import time import unittest -import uuid from collections.abc import Callable from functools import wraps from pathlib import Path @@ -107,104 +104,6 @@ def nest_from_dot(dots: str, value: Any) -> Any: return nested -def schema_prompt(name: str, value: Any | None = None, is_required: bool = False, **options: Any) -> Any: # noqa: PLR0911, PLR0912, PLR0915 - """Interactively prompt based on schema requirements.""" - field_type = options.get("type") - pattern: str | None = options.get("pattern") - enum = options.get("enum", []) - minimum = int(options["minimum"]) if "minimum" in options else None - maximum = int(options["maximum"]) if "maximum" in options else None - min_item = int(options.get("min_items", 0)) - max_items = int(options.get("max_items", 9999)) - - default = options.get("default") - if default is not None and str(default).lower() in ("true", "false"): - default = str(default).lower() - - if "date" in name: - default = time.strftime("%Y/%m/%d") - - if name == "rule_id": - default = str(uuid.uuid4()) - - if len(enum) == 1 and is_required and field_type not in ("array", ["array"]): - return enum[0] - - def _check_type(_val: Any) -> bool: # noqa: PLR0911 - if field_type in ("number", "integer") and not str(_val).isdigit(): - print(f"Number expected but got: {_val}") - return False - if pattern: - match = re.match(pattern, _val) - if not match or len(match.group(0)) != len(_val): - print(f"{_val} did not match pattern: {pattern}!") - return False - if enum and _val not in enum: - print("{} 
not in valid options: {}".format(_val, ", ".join(enum))) - return False - if minimum and (type(_val) is int and int(_val) < minimum): - print(f"{_val!s} is less than the minimum: {minimum!s}") - return False - if maximum and (type(_val) is int and int(_val) > maximum): - print(f"{_val!s} is greater than the maximum: {maximum!s}") - return False - if type(_val) is str and field_type == "boolean" and _val.lower() not in ("true", "false"): - print(f"Boolean expected but got: {_val!s}") - return False - return True - - def _convert_type(_val: Any) -> Any: - if field_type == "boolean" and type(_val) is not bool: - _val = _val.lower() == "true" - return int(_val) if field_type in ("number", "integer") else _val - - prompt = ( - "{name}{default}{required}{multi}".format( - name=name, - default=f' [{default}] ("n/a" to leave blank) ' if default else "", - required=" (required) " if is_required else "", - multi=(" (multi, comma separated) " if field_type in ("array", ["array"]) else ""), - ).strip() - + ": " - ) - - while True: - result = value or input(prompt) or default - if result == "n/a": - result = None - - if not result: - if is_required: - value = None - continue - return None - - if field_type in ("array", ["array"]): - result_list = result.split(",") - - if not (min_item < len(result_list) < max_items): - if is_required: - value = None - break - return [] - - for value in result_list: - if not _check_type(value): - if is_required: - value = None # noqa: PLW2901 - break - return [] - if is_required and value is None: - continue - return [_convert_type(r) for r in result_list] - if _check_type(result): - return _convert_type(result) - if is_required: - value = None - continue - return None - - def get_kibana_rules_map(repo: str = "elastic/kibana", branch: str = "master") -> dict[str, Any]: """Get list of available rules from the Kibana repo and return a list of URLs.""" @@ -355,6 +254,9 @@ def get_elasticsearch_client( # noqa: PLR0913 **kwargs: Any, ) -> Elasticsearch: """Get an authenticated elasticsearch client.""" + # Handle empty strings as None + cloud_id = cloud_id or None + elasticsearch_url = elasticsearch_url or None if not (cloud_id or elasticsearch_url): raise_client_error("Missing required --cloud-id or --elasticsearch-url") @@ -385,6 +287,16 @@ def get_elasticsearch_client( # noqa: PLR0913 return client +def get_default_elasticsearch_client() -> Elasticsearch: + """Get an default authenticated elasticsearch client.""" + return get_elasticsearch_client( + api_key=getdefault("api_key")(), + cloud_id=getdefault("cloud_id")(), + elasticsearch_url=getdefault("elasticsearch_url")(), + ignore_ssl_errors=getdefault("ignore_ssl_errors")(), + ) + + def get_kibana_client( *, api_key: str, @@ -402,6 +314,17 @@ def get_kibana_client( return Kibana(cloud_id=cloud_id, kibana_url=kibana_url, space=space, verify=verify, api_key=api_key, **kwargs) +def get_default_kibana_client() -> Kibana: + """Get an default authenticated Kibana client.""" + return get_kibana_client( + api_key=getdefault("api_key")(), + cloud_id=getdefault("cloud_id")(), + kibana_url=getdefault("kibana_url")(), + space=getdefault("space")(), + ignore_ssl_errors=getdefault("ignore_ssl_errors")(), + ) + + client_options = { "kibana": { "kibana_url": click.Option(["--kibana-url"], default=getdefault("kibana_url")), diff --git a/detection_rules/remote_validation.py b/detection_rules/remote_validation.py index 90c8d1a24f2..36fc78acffa 100644 --- a/detection_rules/remote_validation.py +++ b/detection_rules/remote_validation.py @@ 
-19,6 +19,7 @@ from .config import load_current_package_version from .misc import ClientError, get_elasticsearch_client, get_kibana_client, getdefault from .rule import TOMLRule, TOMLRuleContents +from .rule_validators import ESQLValidator from .schemas import definitions @@ -177,25 +178,40 @@ def request(c: TOMLRuleContents) -> None: return responses # type: ignore[reportUnknownVariableType] - def validate_esql(self, contents: TOMLRuleContents) -> dict[str, Any]: + def validate_esql(self, contents: TOMLRuleContents, index_replacement: bool = False) -> dict[str, Any]: + """Validate query for "esql" rule types. Optionally replace indices and use ESQLValidator.""" query = contents.data.query # type: ignore[reportAttributeAccessIssue] rule_id = contents.data.rule_id - headers = {"accept": "application/json", "content-type": "application/json"} - body = {"query": f"{query} | LIMIT 0"} if not self.es_client: raise ValueError("No ES client found") - try: - response = self.es_client.perform_request( - "POST", - "/_query", - headers=headers, - params={"pretty": True}, - body=body, - ) - except Exception as exc: - if isinstance(exc, elasticsearch.BadRequestError): - raise ValidationError(f"ES|QL query failed: {exc} for rule: {rule_id}, query: \n{query}") from exc - raise Exception(f"ES|QL query failed for rule: {rule_id}, query: \n{query}") from exc # noqa: TRY002 + if not self.kibana_client: + raise ValueError("No Kibana client found") + + if index_replacement: + try: + validator = ESQLValidator(contents.data.query) # type: ignore[reportIncompatibleMethodOverride] + response = validator.remote_validate_rule_contents(self.kibana_client, self.es_client, contents) + except Exception as exc: + if isinstance(exc, elasticsearch.BadRequestError): + raise ValidationError(f"ES|QL query failed: {exc} for rule: {rule_id}, query: \n{query}") from exc + raise Exception(f"ES|QL query failed for rule: {rule_id}, query: \n{query}") from exc # noqa: TRY002 + else: + headers = {"accept": "application/json", "content-type": "application/json"} + body = {"query": f"{query} | LIMIT 0"} + if not self.es_client: + raise ValueError("No ES client found") + try: + response = self.es_client.perform_request( + "POST", + "/_query", + headers=headers, + params={"pretty": True}, + body=body, + ) + except Exception as exc: + if isinstance(exc, elasticsearch.BadRequestError): + raise ValidationError(f"ES|QL query failed: {exc} for rule: {rule_id}, query: \n{query}") from exc + raise Exception(f"ES|QL query failed for rule: {rule_id}, query: \n{query}") from exc # noqa: TRY002 return response.body diff --git a/detection_rules/rule.py b/detection_rules/rule.py index e80cd6848b8..92a65b6882f 100644 --- a/detection_rules/rule.py +++ b/detection_rules/rule.py @@ -29,6 +29,8 @@ from . 
import beats, ecs, endgame, utils from .config import load_current_package_version, parse_rules_config +from .esql import get_esql_query_event_dataset_integrations +from .esql_errors import EsqlSemanticError from .integrations import ( find_least_compatible_version, get_integration_schema_fields, @@ -650,8 +652,6 @@ def validate(self, _: "QueryRuleData", __: RuleMeta) -> None: @cached def get_required_fields(self, index: str) -> list[dict[str, Any]]: """Retrieves fields needed for the query along with type information from the schema.""" - if isinstance(self, ESQLValidator): - return [] current_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) ecs_version = get_stack_schemas()[str(current_version)]["ecs"] @@ -665,7 +665,9 @@ def get_required_fields(self, index: str) -> list[dict[str, Any]]: # construct integration schemas packages_manifest = load_integrations_manifests() integrations_schemas = load_integrations_schemas() - datasets, _ = beats.get_datasets_and_modules(self.ast) + datasets: set[str] = set() + if self.ast: + datasets, _ = beats.get_datasets_and_modules(self.ast) package_integrations = parse_datasets(list(datasets), packages_manifest) int_schema: dict[str, Any] = {} data = {"notify": False} @@ -693,6 +695,9 @@ def get_required_fields(self, index: str) -> list[dict[str, Any]]: elif endgame_schema: field_type = endgame_schema.endgame_schema.get(fld, None) + if not field_type and isinstance(self, ESQLValidator): + field_type = self.get_unique_field_type(fld) + required.append({"name": fld, "type": field_type or "unknown", "ecs": is_ecs}) return sorted(required, key=lambda f: f["name"]) @@ -959,7 +964,7 @@ class ESQLRuleData(QueryRuleData): def validates_esql_data(self, data: dict[str, Any], **_: Any) -> None: """Custom validation for query rule type and subclasses.""" if data.get("index"): - raise ValidationError("Index is not a valid field for ES|QL rule type.") + raise EsqlSemanticError("Index is not a valid field for ES|QL rule type.") # Convert the query string to lowercase to handle case insensitivity query_lower = data["query"].lower() @@ -977,7 +982,7 @@ def validates_esql_data(self, data: dict[str, Any], **_: Any) -> None: # Ensure that non-aggregate queries have metadata if not combined_pattern.search(query_lower): - raise ValidationError( + raise EsqlSemanticError( f"Rule: {data['name']} contains a non-aggregate query without" f" metadata fields '_id', '_version', and '_index' ->" f" Add 'metadata _id, _version, _index' to the from command or add an aggregate function." @@ -987,7 +992,7 @@ def validates_esql_data(self, data: dict[str, Any], **_: Any) -> None: # Match | followed by optional whitespace/newlines and then 'keep' keep_pattern = re.compile(r"\|\s*keep\b", re.IGNORECASE | re.DOTALL) if not keep_pattern.search(query_lower): - raise ValidationError( + raise EsqlSemanticError( f"Rule: {data['name']} does not contain a 'keep' command -> Add a 'keep' command to the query." 
) @@ -1500,21 +1505,22 @@ def get_packaged_integrations( ) -> list[dict[str, Any]] | None: packaged_integrations: list[dict[str, Any]] = [] datasets, _ = beats.get_datasets_and_modules(data.get("ast") or []) # type: ignore[reportArgumentType] - + if isinstance(data, ESQLRuleData): + dataset_objs = get_esql_query_event_dataset_integrations(data.query) + datasets.update(str(obj) for obj in dataset_objs) # integration is None to remove duplicate references upstream in Kibana # chronologically, event.dataset, data_stream.dataset is checked for package:integration, then rule tags # if both exist, rule tags are only used if defined in definitions for non-dataset packages # of machine learning analytic packages - rule_integrations = meta.get("integration", []) - if rule_integrations: - for integration in rule_integrations: - ineligible_integrations = [ - *definitions.NON_DATASET_PACKAGES, - *map(str.lower, definitions.MACHINE_LEARNING_PACKAGES), - ] - if integration in ineligible_integrations or isinstance(data, MachineLearningRuleData): - packaged_integrations.append({"package": integration, "integration": None}) + rule_integrations: str | list[str] = meta.get("integration") or [] + for integration in rule_integrations: + ineligible_integrations = [ + *definitions.NON_DATASET_PACKAGES, + *map(str.lower, definitions.MACHINE_LEARNING_PACKAGES), + ] + if integration in ineligible_integrations or isinstance(data, MachineLearningRuleData): + packaged_integrations.append({"package": integration, "integration": None}) packaged_integrations.extend(parse_datasets(list(datasets), package_manifest)) @@ -1828,7 +1834,7 @@ def parse_datasets(datasets: list[str], package_manifest: dict[str, Any]) -> lis else: package = value - if package in list(package_manifest): + if package in package_manifest: packaged_integrations.append({"package": package, "integration": integration}) return packaged_integrations diff --git a/detection_rules/rule_validators.py b/detection_rules/rule_validators.py index 23d63f77840..c23b67592c3 100644 --- a/detection_rules/rule_validators.py +++ b/detection_rules/rule_validators.py @@ -15,6 +15,8 @@ import eql # type: ignore[reportMissingTypeStubs] import kql # type: ignore[reportMissingTypeStubs] +from elastic_transport import ObjectApiResponse +from elasticsearch import Elasticsearch # type: ignore[reportMissingTypeStubs] from eql import ast # type: ignore[reportMissingTypeStubs] from eql.parser import ( # type: ignore[reportMissingTypeStubs] KvTree, @@ -23,16 +25,29 @@ TypeHint, ) from eql.parser import _parse as base_parse # type: ignore[reportMissingTypeStubs] -from marshmallow import ValidationError +from kibana import Kibana # type: ignore[reportMissingTypeStubs] from semver import Version -from . import ecs, endgame +from . 
import ecs, endgame, misc, utils from .beats import get_datasets_and_modules, parse_beats_from_index from .config import CUSTOM_RULES_DIR, load_current_package_version, parse_rules_config from .custom_schemas import update_auto_generated_schema -from .integrations import get_integration_schema_data, load_integrations_manifests, parse_datasets +from .esql import get_esql_query_event_dataset_integrations +from .esql_errors import EsqlTypeMismatchError +from .index_mappings import ( + create_remote_indices, + execute_query_against_indices, + get_rule_integrations, + prepare_mappings, +) +from .integrations import ( + get_integration_schema_data, + load_integrations_manifests, + parse_datasets, +) from .rule import EQLRuleData, QueryRuleData, QueryValidator, RuleMeta, TOMLRuleContents, set_eql_config -from .schemas import get_stack_schemas +from .schemas import get_latest_stack_version, get_stack_schemas, get_stack_versions +from .schemas.definitions import FROM_SOURCES_REGEX EQL_ERROR_TYPES = ( eql.EqlCompileError @@ -394,6 +409,7 @@ def build_validation_plan(self, data: "QueryRuleData", meta: RuleMeta) -> list[V # Helper for union-by-stack integration targets def add_accumulated_integration_targets(query_text: str, packaged: list[dict[str, Any]], context: str) -> None: + """Add integration-based validation targets by accumulating schemas per stack version.""" combined_by_stack: dict[str, dict[str, Any]] = {} ecs_by_stack: dict[str, str] = {} packages_by_stack: dict[str, set[str]] = {} @@ -527,11 +543,7 @@ def add_stack_targets(query_text: str, include_endgame: bool) -> None: add_stack_targets(synthetic_sequence, include_endgame=False) else: # Datasetless subquery: try metadata integrations first, else add per-subquery stack targets - meta_integrations = meta.integration - if isinstance(meta_integrations, str): - meta_integrations = [meta_integrations] - elif meta_integrations is None: - meta_integrations = [] + meta_integrations = get_rule_integrations(meta) if meta_integrations: meta_pkg_ints = [ @@ -713,28 +725,202 @@ def validate_rule_type_configurations(self, data: EQLRuleData, meta: RuleMeta) - class ESQLValidator(QueryValidator): """Validate specific fields for ESQL query event types.""" - @cached_property - def ast(self) -> None: # type: ignore[reportIncompatibleMethodOverride] + kibana_client: Kibana + elastic_client: Elasticsearch + metadata: RuleMeta + rule_id: str + verbosity: int + esql_unique_fields: list[dict[str, str]] + + def log(self, val: str) -> None: + """Log if verbosity is 1 or greater (1 corresponds to `-v` in pytest)""" + unit_test_verbose_level = 1 + if self.verbosity >= unit_test_verbose_level: + print(f"{self.rule_id}:", val) + + @property + def ast(self) -> Any: + """Return the AST of the ESQL query. Dependant in ESQL parser which is not implemented""" + # Needs to return none to prevent not implemented error return None @cached_property def unique_fields(self) -> list[str]: # type: ignore[reportIncompatibleMethodOverride] - """Return a list of unique fields in the query.""" - # return empty list for ES|QL rules until ast is available (friendlier than raising error) + """Return a list of unique fields in the query. 
Requires remote validation to have occurred.""" + esql_unique_fields = getattr(self, "esql_unique_fields", None) + if esql_unique_fields: + return [field["name"] for field in self.esql_unique_fields] return [] - def validate(self, _: "QueryRuleData", __: RuleMeta) -> None: # type: ignore[reportIncompatibleMethodOverride] + def get_esql_query_indices(self, query: str) -> tuple[str, list[str]]: + """Extract indices from an ES|QL query.""" + match = FROM_SOURCES_REGEX.search(query) + + if not match: + return "", [] + + sources_str = match.group("sources") + return sources_str, [source.strip() for source in sources_str.split(",")] + + def get_unique_field_type(self, field_name: str) -> str | None: # type: ignore[reportIncompatibleMethodOverride] + """Get the type of the unique field. Requires remote validation to have occurred.""" + esql_unique_fields = getattr(self, "esql_unique_fields", []) + for field in esql_unique_fields: + if field["name"] == field_name: + return field["type"] + return None + + def validate_columns_index_mapping( + self, query_columns: list[dict[str, str]], combined_mappings: dict[str, Any], version: str = "", query: str = "" + ) -> bool: + """Validate that the columns in the ESQL query match the provided mappings.""" + mismatched_columns: list[str] = [] + + for column in query_columns: + column_name = column["name"] + # Skip Dynamic fields + if column_name.startswith(("Esql.", "Esql_priv.")): + continue + # Skip internal fields + if column_name in ("_id", "_version", "_index"): + continue + # Skip implicit fields + if column_name not in query: + continue + column_type = column["type"] + + # Check if the column exists in combined_mappings or a valid field generated from a function or operator + keys = column_name.split(".") + schema_type = utils.get_column_from_index_mapping_schema(keys, combined_mappings) + schema_type = kql.parser.elasticsearch_type_family(schema_type) if schema_type else None + + # If it is in the schema, but Kibana returns unsupported + if schema_type and column_type == "unsupported": + continue + + # Validate the type + if not schema_type or column_type != schema_type: + # Attempt reverse mapping as for our purposes they are equivalent. + # We are generally concerned about the operators for the types not the values themselves. + reverse_col_type = kql.parser.elasticsearch_type_family(column_type) if column_type else None + if reverse_col_type is not None and schema_type is not None and reverse_col_type == schema_type: + continue + mismatched_columns.append( + f"Dynamic field `{column_name}` is not correctly mapped. " + f"If not dynamic: expected from schema: `{schema_type}`, got from Kibana: `{column_type}`." 
+                )
+
+        if mismatched_columns:
+            raise EsqlTypeMismatchError(
+                f"Column validation errors in Stack Version {version}:\n" + "\n".join(mismatched_columns)
+            )
+
+        return True
+
+    def validate(self, data: "QueryRuleData", rule_meta: RuleMeta, force_remote_validation: bool = False) -> None:  # type: ignore[reportIncompatibleMethodOverride]
         """Validate an ESQL query while checking TOMLRule."""
-        # temporarily override to NOP until ES|QL query parsing is supported
+        if misc.getdefault("remote_esql_validation")() or force_remote_validation:
+            resolved_kibana_options = {
+                str(option.name): option.default() if callable(option.default) else option.default
+                for option in misc.kibana_options
+                if option.name is not None
+            }
+
+            resolved_elastic_options = {
+                option.name: option.default() if callable(option.default) else option.default
+                for option in misc.elasticsearch_options
+                if option.name is not None
+            }
+
+            with (
+                misc.get_kibana_client(**resolved_kibana_options) as kibana_client,  # type: ignore[reportUnknownVariableType]
+                misc.get_elasticsearch_client(**resolved_elastic_options) as elastic_client,  # type: ignore[reportUnknownVariableType]
+            ):
+                _ = self.remote_validate_rule(
+                    kibana_client,
+                    elastic_client,
+                    data.query,
+                    rule_meta,
+                    data.rule_id,
+                )

-    def validate_integration(
+    def remote_validate_rule_contents(
+        self, kibana_client: Kibana, elastic_client: Elasticsearch, contents: TOMLRuleContents, verbosity: int = 0
+    ) -> ObjectApiResponse[Any]:
+        """Remotely validate a rule's ES|QL query using an Elastic Stack."""
+        return self.remote_validate_rule(
+            kibana_client=kibana_client,
+            elastic_client=elastic_client,
+            query=contents.data.query,  # type: ignore[reportUnknownVariableType]
+            metadata=contents.metadata,
+            rule_id=contents.data.rule_id,
+            verbosity=verbosity,
+        )
+
+    def remote_validate_rule(  # noqa: PLR0913
         self,
-        _: QueryRuleData,
-        __: RuleMeta,
-        ___: list[dict[str, Any]],
-    ) -> ValidationError | None | ValueError:
-        # Disabling self.validate(data, meta)
-        pass
+        kibana_client: Kibana,
+        elastic_client: Elasticsearch,
+        query: str,
+        metadata: RuleMeta,
+        rule_id: str = "",
+        verbosity: int = 0,
+    ) -> ObjectApiResponse[Any]:
+        """Use remote validation against an Elastic Stack to validate a given rule's ES|QL query."""
+
+        self.rule_id = rule_id
+        self.verbosity = verbosity
+
+        # Validate that all fields (columns) are either dynamic fields or correctly mapped
+        # against the combined mapping of all the indices
+        kibana_details: dict[str, Any] = kibana_client.get("/api/status", {})  # type: ignore[reportUnknownVariableType]
+        if "version" not in kibana_details:
+            raise ValueError("Failed to retrieve Kibana details.")
+        stack_version = get_latest_stack_version()
+
+        self.log(f"Validating against {stack_version} stack")
+        indices_str, indices = self.get_esql_query_indices(query)  # type: ignore[reportUnknownVariableType]
+        self.log(f"Extracted indices from query: {', '.join(indices)}")
+
+        event_dataset_integrations = get_esql_query_event_dataset_integrations(query)
+        self.log(f"Extracted event.dataset integrations from query: {event_dataset_integrations}")
+
+        # Get mappings for all matching existing index templates
+        existing_mappings, index_lookup, combined_mappings = prepare_mappings(
+            elastic_client, indices, event_dataset_integrations, metadata, stack_version, self.log
+        )
+        self.log(f"Collected mappings: {len(existing_mappings)}")
+        self.log(f"Combined mappings prepared: {len(combined_mappings)}")
+
+        # Create remote indices
+        full_index_str = create_remote_indices(elastic_client, existing_mappings, index_lookup, self.log)
+
+        # Replace all sources with the test indices
+        query = query.replace(indices_str, full_index_str)  # type: ignore[reportUnknownVariableType]
+
+        query_columns, response = execute_query_against_indices(elastic_client, query, full_index_str, self.log)  # type: ignore[reportUnknownVariableType]
+        self.esql_unique_fields = query_columns
+
+        # Build a mapping lookup for all stack versions to validate against.
+        # We only need to check the schemas locally for the type mismatch error,
+        # as the EsqlSchemaError and EsqlSyntaxError errors from the stack
+        # are not affected by differences in schema type mappings.
+        mappings_lookup: dict[str, dict[str, Any]] = {stack_version: combined_mappings}
+        versions = get_stack_versions()
+        for version in versions:
+            if version in mappings_lookup:
+                continue
+            _, _, combined_mappings = prepare_mappings(
+                elastic_client, indices, event_dataset_integrations, metadata, version, self.log
+            )
+            mappings_lookup[version] = combined_mappings
+
+        for version, mapping in mappings_lookup.items():
+            self.log(f"Validating {rule_id} against {version} stack")
+            if not self.validate_columns_index_mapping(query_columns, mapping, version=version, query=query):
+                self.log("Dynamic column(s) have improper formatting.")
+
+        return response


 def extract_error_field(source: str, exc: eql.EqlParseError | kql.KqlParseError) -> str | None:
diff --git a/detection_rules/schemas/__init__.py b/detection_rules/schemas/__init__.py
index c1fc8b4b7af..62228158a0a 100644
--- a/detection_rules/schemas/__init__.py
+++ b/detection_rules/schemas/__init__.py
@@ -395,6 +395,12 @@ def get_stack_versions(drop_patch: bool = False) -> list[str]:
     return versions


+def get_latest_stack_version(drop_patch: bool = False) -> str:
+    """Get the latest defined and supported stack version."""
+    parsed_versions = [Version.parse(version) for version in get_stack_versions(drop_patch=drop_patch)]
+    return str(max(parsed_versions))
+
+
 @cached
 def get_min_supported_stack_version() -> Version:
     """Get the minimum defined and supported stack version."""
diff --git a/detection_rules/schemas/definitions.py b/detection_rules/schemas/definitions.py
index ac2966e2afd..0fd2be2e4ed 100644
--- a/detection_rules/schemas/definitions.py
+++ b/detection_rules/schemas/definitions.py
@@ -56,6 +56,7 @@ def validator_wrapper(value: Any) -> Any:
     return validator_wrapper


+HTTP_STATUS_BAD_REQUEST = 400
 ASSET_TYPE = "security_rule"
 SAVED_OBJECT_TYPE = "security-rule"
@@ -75,6 +76,7 @@ def validator_wrapper(value: Any) -> Any:
 CONDITION_VERSION_PATTERN = re.compile(rf"^\^{_version}$")
 VERSION_PATTERN = f"^{_version}$"
 MINOR_SEMVER = re.compile(r"^\d+\.\d+$")
+FROM_SOURCES_REGEX = re.compile(r"^\s*FROM\s+(?P<sources>.+?)\s*(?:\||\bmetadata\b|//|$)", re.IGNORECASE | re.MULTILINE)
 BRANCH_PATTERN = f"{VERSION_PATTERN}|^master$"
 ELASTICSEARCH_EQL_FEATURES = {
     "allow_negation": (Version.parse("8.9.0"), None),
diff --git a/detection_rules/utils.py b/detection_rules/utils.py
index 0f84dcd8d97..1a67fed1e37 100644
--- a/detection_rules/utils.py
+++ b/detection_rules/utils.py
@@ -528,3 +528,41 @@ def get_identifiers(self) -> list[str]:
             # another group we're not expecting
             raise ValueError("Unrecognized named group in pattern", self.pattern)
         return ids
+
+
+def convert_to_nested_schema(flat_schemas: dict[str, str]) -> dict[str, Any]:
+    """Convert a flat schema to a nested schema with 'properties' for each sub-key."""
+    # NOTE: this is needed to conform to Kibana's index mapping format
+    nested_schema = {}
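+    # For example (illustrative field name only), {"host.os.name": "keyword"} becomes:
+    #   {"host": {"properties": {"os": {"properties": {"name": {"type": "keyword"}}}}}}
+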
+    for key, value in flat_schemas.items():
+        parts = key.split(".")
+        current_level = nested_schema
+
+        for part in parts[:-1]:
+            current_level = current_level.setdefault(part, {}).setdefault("properties", {})  # type: ignore[reportUnknownVariableType]
+
+        current_level[parts[-1]] = {"type": value}
+
+    return nested_schema  # type: ignore[reportUnknownVariableType]
+
+
+def combine_dicts(dest: dict[Any, Any], src: dict[Any, Any]) -> None:
+    """Combine two dictionaries recursively."""
+    for k, v in src.items():
+        if k in dest and isinstance(dest[k], dict) and isinstance(v, dict):
+            combine_dicts(dest[k], v)  # type: ignore[reportUnknownVariableType]
+        else:
+            dest[k] = v
+
+
+def get_column_from_index_mapping_schema(keys: list[str], current_schema: dict[str, Any] | None) -> str | None:
+    """Recursively traverse the schema to find the type of the column."""
+    key = keys[0]
+    if not current_schema:
+        return None
+    column = current_schema.get(key) or {}  # type: ignore[reportUnknownVariableType]
+    column_type = column.get("type") if column else None  # type: ignore[reportUnknownVariableType]
+    if len(keys) > 1:
+        return get_column_from_index_mapping_schema(keys[1:], current_schema=column.get("properties"))  # type: ignore[reportUnknownVariableType]
+    return column_type  # type: ignore[reportUnknownVariableType]
diff --git a/hunting/definitions.py b/hunting/definitions.py
index 01e519958e3..717196b6712 100644
--- a/hunting/definitions.py
+++ b/hunting/definitions.py
@@ -59,5 +59,5 @@ def validate_esql_query(self, query: str) -> None:
         # Check if either "stats by" or "| keep" exists in the query
         if not stats_by_pattern.search(query) and not keep_pattern.search(query):
             raise ValueError(
-                f"Hunt: {self.name} contains an ES|QL query that mustcontain either 'stats by' or 'keep' functions."
+                f"Hunt: {self.name} contains an ES|QL query that must contain either 'stats by' or 'keep' functions"
             )
diff --git a/pyproject.toml b/pyproject.toml
index d8d1747fc57..f91d63f270c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "detection_rules"
-version = "1.4.11"
+version = "1.5.0"
 description = "Detection Rules is the home for rules used by Elastic Security. This repository is used for the development, maintenance, testing, validation, and release of rules for Elastic Security’s Detection Engine."
 readme = "README.md"
 requires-python = ">=3.12"
diff --git a/tests/test_rules_remote.py b/tests/test_rules_remote.py
index 11ff1c36be3..a505b7ef863 100644
--- a/tests/test_rules_remote.py
+++ b/tests/test_rules_remote.py
@@ -4,20 +4,103 @@
 # 2.0.
 
 import unittest
+from copy import deepcopy
 
-from detection_rules.misc import get_default_config
-from detection_rules.remote_validation import RemoteValidator
+import pytest
+
+from detection_rules.esql_errors import EsqlSchemaError, EsqlSyntaxError, EsqlTypeMismatchError
+from detection_rules.misc import (
+    get_default_config,
+    getdefault,
+)
+from detection_rules.rule_loader import RuleCollection
+from detection_rules.utils import get_path, load_rule_contents
 
 from .base import BaseRuleTest
 
+MAX_RETRIES = 3
+
 
 @unittest.skipIf(get_default_config() is None, "Skipping remote validation due to missing config")
+@unittest.skipIf(
+    not getdefault("remote_esql_validation")(), "Skipping remote validation because remote_esql_validation is False"
+)
 class TestRemoteRules(BaseRuleTest):
     """Test rules against a remote Elastic stack instance."""
 
-    @unittest.skip("Temporarily disabled")
-    def test_esql_rules(self):
-        """Temporarily explicitly test all ES|QL rules remotely pending parsing lib."""
-        esql_rules = [r for r in self.all_rules if r.contents.data.type == "esql"]
-        rv = RemoteValidator(parse_config=True)
-        rv.validate_rules(esql_rules)
+    def test_esql_related_integrations(self):
+        """Test that an ESQL rule has its related integrations built correctly."""
+        file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"])
+        original_production_rule = load_rule_contents(file_path)
+        production_rule = deepcopy(original_production_rule)[0]
+        production_rule["metadata"]["integration"] = ["aws"]
+        production_rule["rule"]["query"] = """
+        from logs-aws.cloudtrail* metadata _id, _version, _index
+        | where @timestamp > now() - 30 minutes
+            and event.dataset in ("aws.cloudtrail", "aws.billing")
+            and aws.cloudtrail.user_identity.arn is not null
+            and aws.cloudtrail.user_identity.type == "IAMUser"
+        | keep
+            aws.cloudtrail.user_identity.type
+        """
+        rule = RuleCollection().load_dict(production_rule)
+        related_integrations = rule.contents.to_api_format()["related_integrations"]
+        for integration in related_integrations:
+            assert integration["package"] == "aws", f"Expected 'aws', but got {integration['package']}"
+
+    def test_esql_event_dataset_schema_error(self):
+        """Test that an ESQL rule using the event.dataset field in its query has its fields validated correctly."""
+        # EsqlSchemaError
+        file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"])
+        original_production_rule = load_rule_contents(file_path)
+        # Test that an EsqlSchemaError is raised if the query doesn't match the schema
+        production_rule = deepcopy(original_production_rule)[0]
+        del production_rule["metadata"]["integration"]
+        production_rule["rule"]["query"] = """
+        from logs-aws.cloudtrail* metadata _id, _version, _index
+        | where @timestamp > now() - 30 minutes
+            and event.dataset in ("aws.billing")
+            and aws.cloudtrail.user_identity.type == "IAMUser"
+        | keep
+            aws.cloudtrail.user_identity.type
+        """
+        with pytest.raises(EsqlSchemaError):
+            _ = RuleCollection().load_dict(production_rule)
+
+    def test_esql_type_mismatch_error(self):
+        """Test that an ESQL rule comparing a field against a value of the wrong type raises a type mismatch error."""
+        # EsqlTypeMismatchError
+        file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"])
+        original_production_rule = load_rule_contents(file_path)
+        # Test that an EsqlTypeMismatchError is raised if a field is compared against a value of the wrong type
+        production_rule = deepcopy(original_production_rule)[0]
+        production_rule["metadata"]["integration"] = ["aws"]
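+        # The query below compares a string-typed field (user_identity.type) to an integer (== 5),
+        # which should surface as an EsqlTypeMismatchError during validation.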
production_rule["rule"]["query"] = """ + from logs-aws.cloudtrail* metadata _id, _version, _index + | where @timestamp > now() - 30 minutes + and event.dataset in ("aws.cloudtrail", "aws.billing") + and aws.cloudtrail.user_identity.type == 5 + | keep + aws.cloudtrail.user_identity.type + """ + with pytest.raises(EsqlTypeMismatchError): + _ = RuleCollection().load_dict(production_rule) + + def test_esql_syntax_error(self): + """Test an ESQL rules that uses event.dataset field in the query validated the fields correctly.""" + # EsqlSchemaError + file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"]) + original_production_rule = load_rule_contents(file_path) + # Test that a ValidationError is raised if the query doesn't match the schema + production_rule = deepcopy(original_production_rule)[0] + production_rule["metadata"]["integration"] = ["aws"] + production_rule["rule"]["query"] = """ + from logs-aws.cloudtrail* metadata _id, _version, _index + | where @timestamp > now() - 30 minutes + and event.dataset in ("aws.cloudtrail", "aws.billing") + and aws.cloudtrail.user_identity.type = "IAMUser" + | keep + aws.cloudtrail.user_identity.type + """ + with pytest.raises(EsqlSyntaxError): + _ = RuleCollection().load_dict(production_rule) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 5148c3d2b39..1215e4b0fbc 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -18,6 +18,7 @@ from detection_rules import utils from detection_rules.config import load_current_package_version +from detection_rules.esql_errors import EsqlSemanticError from detection_rules.rule import TOMLRuleContents from detection_rules.rule_loader import RuleCollection from detection_rules.schemas import RULES_CONFIG, downgrade @@ -315,7 +316,7 @@ def test_esql_data_validation(self): """Test ESQL rule data validation""" # A random ESQL rule to deliver a test query - rule_path = Path("rules/windows/defense_evasion_posh_obfuscation_index_reversal.toml") + rule_path = Path("tests/data/command_control_dummy_production_rule.toml") rule_body = rule_path.read_text() rule_dict = pytoml.loads(rule_body) @@ -323,7 +324,7 @@ def test_esql_data_validation(self): query = """ FROM logs-windows.powershell_operational* METADATA _id, _version, _index | WHERE event.code == "4104" - | KEEP event.count + | KEEP event.code """ rule_dict["rule"]["query"] = query _ = RuleCollection().load_dict(rule_dict, path=rule_path) @@ -333,23 +334,23 @@ def test_esql_data_validation(self): query = """ FROM logs-windows.powershell_operational* METADATA _id, _index, _version | WHERE event.code == "4104" - | KEEP event.count + | KEEP event.code """ rule_dict["rule"]["query"] = query _ = RuleCollection().load_dict(rule_dict, path=rule_path) # Different metadata fields - with pytest.raises(ValidationError): + with pytest.raises(EsqlSemanticError): query = """ FROM logs-windows.powershell_operational* METADATA _foo, _index | WHERE event.code == "4104" - | KEEP event.count + | KEEP event.code """ rule_dict["rule"]["query"] = query _ = RuleCollection().load_dict(rule_dict, path=rule_path) # Missing `keep` - with pytest.raises(ValidationError): + with pytest.raises(EsqlSemanticError): query = """ FROM logs-windows.powershell_operational* METADATA _id, _index, _version | WHERE event.code == "4104"