diff --git a/src/main/java/org/opensearch/securityanalytics/model/Detector.java b/src/main/java/org/opensearch/securityanalytics/model/Detector.java index ecf655f78..58e0b06c4 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/Detector.java +++ b/src/main/java/org/opensearch/securityanalytics/model/Detector.java @@ -4,6 +4,8 @@ */ package org.opensearch.securityanalytics.model; +import java.util.HashMap; +import java.util.Map; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.common.ParseField; @@ -49,6 +51,8 @@ public class Detector implements Writeable, ToXContentObject { public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; public static final String ENABLED_TIME_FIELD = "enabled_time"; public static final String ALERTING_MONITOR_ID = "monitor_id"; + + public static final String BUCKET_MONITOR_ID_RULE_ID = "bucket_monitor_id_rule_id"; private static final String RULE_TOPIC_INDEX = "rule_topic_index"; private static final String ALERTS_INDEX = "alert_index"; @@ -59,6 +63,9 @@ public class Detector implements Writeable, ToXContentObject { public static final String DETECTORS_INDEX = ".opensearch-detectors-config"; + // Used as a key in rule-monitor map for the purpose of easy detection of the doc level monitor + public static final String DOC_LEVEL_MONITOR = "-1"; + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( Detector.class, new ParseField(DETECTOR_TYPE), @@ -90,6 +97,8 @@ public class Detector implements Writeable, ToXContentObject { private List monitorIds; + private Map ruleIdMonitorId; + private String ruleIndex; private String alertsIndex; @@ -108,7 +117,7 @@ public Detector(String id, Long version, String name, Boolean enabled, Schedule Instant lastUpdateTime, Instant enabledTime, DetectorType detectorType, User user, List inputs, List triggers, List monitorIds, String ruleIndex, String alertsIndex, String alertsHistoryIndex, String alertsHistoryIndexPattern, - String findingsIndex, String findingsIndexPattern) { + String findingsIndex, String findingsIndexPattern, Map rulePerMonitor) { this.type = DETECTOR_TYPE; this.id = id != null ? 
id : NO_ID; @@ -129,6 +138,7 @@ public Detector(String id, Long version, String name, Boolean enabled, Schedule this.alertsHistoryIndexPattern = alertsHistoryIndexPattern; this.findingsIndex = findingsIndex; this.findingsIndexPattern = findingsIndexPattern; + this.ruleIdMonitorId = rulePerMonitor; if (enabled) { Objects.requireNonNull(enabledTime); @@ -154,7 +164,9 @@ public Detector(StreamInput sin) throws IOException { sin.readString(), sin.readString(), sin.readString(), - sin.readString()); + sin.readString(), + sin.readMap(StreamInput::readString, StreamInput::readString) + ); } @Override @@ -186,6 +198,8 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeStringCollection(monitorIds); out.writeString(ruleIndex); + + out.writeMap(ruleIdMonitorId, StreamOutput::writeString, StreamOutput::writeString); } public XContentBuilder toXContentWithUser(XContentBuilder builder, Params params) throws IOException { @@ -268,6 +282,7 @@ private XContentBuilder createXContentBuilder(XContentBuilder builder, ToXConten } builder.field(ALERTING_MONITOR_ID, monitorIds); + builder.field(BUCKET_MONITOR_ID_RULE_ID, ruleIdMonitorId); builder.field(RULE_TOPIC_INDEX, ruleIndex); builder.field(ALERTS_INDEX, alertsIndex); builder.field(ALERTS_HISTORY_INDEX, alertsHistoryIndex); @@ -312,6 +327,8 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws List inputs = new ArrayList<>(); List triggers = new ArrayList<>(); List monitorIds = new ArrayList<>(); + Map rulePerMonitor = new HashMap<>(); + String ruleIndex = null; String alertsIndex = null; String alertsHistoryIndex = null; @@ -390,6 +407,9 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws monitorIds.add(monitorId); } break; + case BUCKET_MONITOR_ID_RULE_ID: + rulePerMonitor= xcp.mapStrings(); + break; case RULE_TOPIC_INDEX: ruleIndex = xcp.text(); break; @@ -437,7 +457,8 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws alertsHistoryIndex, alertsHistoryIndexPattern, findingsIndex, - findingsIndexPattern); + findingsIndexPattern, + rulePerMonitor); } public static Detector readFrom(StreamInput sin) throws IOException { @@ -516,6 +537,8 @@ public List getMonitorIds() { return monitorIds; } + public Map getRuleIdMonitorId() {return ruleIdMonitorId; } + public void setId(String id) { this.id = id; } @@ -563,6 +586,13 @@ public void setInputs(List inputs) { public void setMonitorIds(List monitorIds) { this.monitorIds = monitorIds; } + public void setRuleIdMonitorId(Map ruleIdMonitorId) { + this.ruleIdMonitorId = ruleIdMonitorId; + } + + public String getDocLevelMonitorId() { + return ruleIdMonitorId.get(DOC_LEVEL_MONITOR); + } @Override public boolean equals(Object o) { diff --git a/src/main/java/org/opensearch/securityanalytics/model/Rule.java b/src/main/java/org/opensearch/securityanalytics/model/Rule.java index 4fc6cfc05..9c269ece7 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/Rule.java +++ b/src/main/java/org/opensearch/securityanalytics/model/Rule.java @@ -452,7 +452,6 @@ public boolean isAggregationRule() { return aggregationQueries != null && !aggregationQueries.isEmpty(); } - // TODO - temp method; Replace once you have some more inputs from Shubo and Surya public List getAggregationItemsFromRule () throws SigmaError { SigmaRule sigmaRule = SigmaRule.fromYaml(rule, true); List aggregationItems = new ArrayList<>(); diff --git a/src/main/java/org/opensearch/securityanalytics/rules/backend/AggregationBuilders.java 
b/src/main/java/org/opensearch/securityanalytics/rules/backend/AggregationBuilders.java index baa73e30e..df6359f33 100644 --- a/src/main/java/org/opensearch/securityanalytics/rules/backend/AggregationBuilders.java +++ b/src/main/java/org/opensearch/securityanalytics/rules/backend/AggregationBuilders.java @@ -1,76 +1,29 @@ package org.opensearch.securityanalytics.rules.backend; import org.opensearch.search.aggregations.AggregationBuilder; -import org.opensearch.search.aggregations.bucket.histogram.AutoDateHistogramAggregationBuilder; -import org.opensearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder; -import org.opensearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder; -import org.opensearch.search.aggregations.bucket.histogram.VariableWidthHistogramAggregationBuilder; -import org.opensearch.search.aggregations.bucket.range.DateRangeAggregationBuilder; -import org.opensearch.search.aggregations.bucket.range.GeoDistanceAggregationBuilder; -import org.opensearch.search.aggregations.bucket.range.IpRangeAggregationBuilder; -import org.opensearch.search.aggregations.bucket.range.RangeAggregationBuilder; -import org.opensearch.search.aggregations.bucket.sampler.DiversifiedAggregationBuilder; -import org.opensearch.search.aggregations.bucket.terms.RareTermsAggregationBuilder; -import org.opensearch.search.aggregations.bucket.terms.SignificantTermsAggregationBuilder; import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.search.aggregations.metrics.AvgAggregationBuilder; -import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; -import org.opensearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder; -import org.opensearch.search.aggregations.metrics.GeoCentroidAggregationBuilder; import org.opensearch.search.aggregations.metrics.MaxAggregationBuilder; import org.opensearch.search.aggregations.metrics.MedianAbsoluteDeviationAggregationBuilder; import org.opensearch.search.aggregations.metrics.MinAggregationBuilder; -import org.opensearch.search.aggregations.metrics.PercentileRanksAggregationBuilder; -import org.opensearch.search.aggregations.metrics.PercentilesAggregationBuilder; -import org.opensearch.search.aggregations.metrics.StatsAggregationBuilder; import org.opensearch.search.aggregations.metrics.SumAggregationBuilder; import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder; + public final class AggregationBuilders { /** * Finds the builder aggregation based on the forwarded function * - * @param aggregationFunction - aggregation function - * @param name - name of the aggregation - * @return + * @param aggregationFunction Aggregation function + * @param name Name of the aggregation + * @return Aggregation builder */ - public static AggregationBuilder getAggregationBuilderByFunction(String aggregationFunction, String name){ + public static AggregationBuilder getAggregationBuilderByFunction(String aggregationFunction, String name) { AggregationBuilder aggregationBuilder; - switch (aggregationFunction){ - case AutoDateHistogramAggregationBuilder.NAME: - aggregationBuilder = new AutoDateHistogramAggregationBuilder(name).field(name); - break; + switch (aggregationFunction.toLowerCase()) { case AvgAggregationBuilder.NAME: aggregationBuilder = new AvgAggregationBuilder(name).field(name); break; - case CardinalityAggregationBuilder.NAME: - aggregationBuilder = new CardinalityAggregationBuilder(name).field(name); - break; - case 
DateHistogramAggregationBuilder.NAME: - aggregationBuilder = new DateHistogramAggregationBuilder(name).field(name); - break; - case DateRangeAggregationBuilder.NAME: - aggregationBuilder = new DateRangeAggregationBuilder(name).field(name); - break; - case DiversifiedAggregationBuilder.NAME: - aggregationBuilder = new DiversifiedAggregationBuilder(name).field(name); - break; - case ExtendedStatsAggregationBuilder.NAME: - aggregationBuilder = new ExtendedStatsAggregationBuilder(name).field(name); - break; - case GeoCentroidAggregationBuilder.NAME: - aggregationBuilder = new GeoCentroidAggregationBuilder(name).field(name); - break; - // TODO ? - case GeoDistanceAggregationBuilder.NAME: - aggregationBuilder = new GeoDistanceAggregationBuilder(name, null).field(name); - break; - case HistogramAggregationBuilder.NAME: - aggregationBuilder = new HistogramAggregationBuilder(name).field(name); - break; - case IpRangeAggregationBuilder.NAME: - aggregationBuilder = new IpRangeAggregationBuilder(name).field(name); - break; case MaxAggregationBuilder.NAME: aggregationBuilder = new MaxAggregationBuilder(name).field(name); break; @@ -80,38 +33,17 @@ public static AggregationBuilder getAggregationBuilderByFunction(String aggregat case MinAggregationBuilder.NAME: aggregationBuilder = new MinAggregationBuilder(name).field(name); break; - // TODO - do we need this? - case PercentileRanksAggregationBuilder.NAME: - aggregationBuilder = new PercentileRanksAggregationBuilder(name, null).field(name); - break; - case PercentilesAggregationBuilder.NAME: - aggregationBuilder = new PercentilesAggregationBuilder(name).field(name); - break; - case RangeAggregationBuilder.NAME: - aggregationBuilder = new RangeAggregationBuilder(name).field(name); - break; - case RareTermsAggregationBuilder.NAME: - aggregationBuilder = new RareTermsAggregationBuilder(name).field(name); - break; - case SignificantTermsAggregationBuilder.NAME: - aggregationBuilder = new SignificantTermsAggregationBuilder(name).field(name); - break; - case StatsAggregationBuilder.NAME: - aggregationBuilder = new StatsAggregationBuilder(name).field(name); - break; case SumAggregationBuilder.NAME: aggregationBuilder = new SumAggregationBuilder(name).field(name); break; case TermsAggregationBuilder.NAME: aggregationBuilder = new TermsAggregationBuilder(name).field(name); break; - case ValueCountAggregationBuilder.NAME: + case "count": aggregationBuilder = new ValueCountAggregationBuilder(name).field(name); break; - case VariableWidthHistogramAggregationBuilder.NAME: - aggregationBuilder = new VariableWidthHistogramAggregationBuilder(name).field(name); - break; - default: return null; + default: + return null; } return aggregationBuilder; } diff --git a/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java b/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java index 637ffbf26..b243c884c 100644 --- a/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java +++ b/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java @@ -389,8 +389,11 @@ public AggregationQueries convertAggregation(AggregationItem aggregation) { fmtAggQuery = String.format(Locale.getDefault(), aggQuery, "result_agg", aggregation.getGroupByField(), aggregation.getAggField(), aggregation.getAggFunction(), aggregation.getAggField()); fmtBucketTriggerQuery = String.format(Locale.getDefault(), bucketTriggerQuery, aggregation.getAggField(), aggregation.getAggField(), "result_agg", 
aggregation.getAggField(), aggregation.getCompOperator(), aggregation.getThreshold()); + // Add subaggregation AggregationBuilder subAgg = AggregationBuilders.getAggregationBuilderByFunction(aggregation.getAggFunction(), aggregation.getAggField()); - aggBuilder.field(aggregation.getGroupByField()).subAggregation(subAgg); + if (subAgg != null) { + aggBuilder.field(aggregation.getGroupByField()).subAggregation(subAgg); + } Script script = new Script(String.format(Locale.getDefault(), bucketTriggerScript, aggregation.getAggField(), aggregation.getCompOperator(), aggregation.getThreshold())); condition = new BucketSelectorExtAggregationBuilder(bucketTriggerSelectorId, Collections.singletonMap(aggregation.getAggField(), aggregation.getAggField()), script, "result_agg", null); diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java index 004b7a7e5..37855e708 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java @@ -4,11 +4,11 @@ */ package org.opensearch.securityanalytics.transport; -import static java.util.Collections.emptyList; - import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Locale; @@ -20,9 +20,11 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.util.SetOnce; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRunnable; +import org.opensearch.action.StepListener; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.get.GetRequest; @@ -32,6 +34,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.GroupedActionListener; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.support.WriteRequest; import org.opensearch.action.support.WriteRequest.RefreshPolicy; @@ -39,23 +42,20 @@ import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.inject.Inject; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.common.io.stream.NamedWriteableRegistry; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.ToXContent; -import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentType; import 
org.opensearch.commons.alerting.AlertingPluginInterface; +import org.opensearch.commons.alerting.action.DeleteMonitorRequest; +import org.opensearch.commons.alerting.action.DeleteMonitorResponse; import org.opensearch.commons.alerting.action.IndexMonitorRequest; import org.opensearch.commons.alerting.action.IndexMonitorResponse; import org.opensearch.commons.alerting.model.BucketLevelTrigger; @@ -72,11 +72,11 @@ import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.seqno.SequenceNumbers; import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestRequest.Method; import org.opensearch.rest.RestStatus; import org.opensearch.script.Script; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; -import org.opensearch.search.SearchModule; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.securityanalytics.action.IndexDetectorAction; import org.opensearch.securityanalytics.action.IndexDetectorRequest; @@ -127,8 +127,10 @@ public class TransportIndexDetectorAction extends HandledTransportAction> rulesById, Detector detector, ActionListener listener, WriteRequest.RefreshPolicy refreshPolicy) throws SigmaError, IOException { + private void createMonitorFromQueries(String index, List> rulesById, Detector detector, ActionListener>listener, WriteRequest.RefreshPolicy refreshPolicy) throws SigmaError, IOException { List> docLevelRules = rulesById.stream().filter(it -> !it.getRight().isAggregationRule()).collect( Collectors.toList()); List> bucketLevelRules = rulesById.stream().filter(it -> it.getRight().isAggregationRule()).collect( Collectors.toList()); - createAlertingMonitorFromQueries(Pair.of(index, docLevelRules), detector, listener, refreshPolicy); - createBucketMonitorFromQueries(Pair.of(index, bucketLevelRules), detector, listener, refreshPolicy); + List monitorRequests = new ArrayList<>(); + + if (!docLevelRules.isEmpty()) { + monitorRequests.add(createDocLevelMonitorRequest(Pair.of(index, docLevelRules), detector, refreshPolicy, Monitor.NO_ID, Method.POST)); + } + if (!bucketLevelRules.isEmpty()) { + monitorRequests.addAll(buildBucketLevelMonitorRequests(Pair.of(index, bucketLevelRules), detector, refreshPolicy, Monitor.NO_ID, Method.POST)); + } + + GroupedActionListener monitorResponseListener = new GroupedActionListener( + new ActionListener>() { + @Override + public void onResponse(Collection indexMonitorResponse) { + listener.onResponse(indexMonitorResponse.stream().collect(Collectors.toList())); + } + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }, monitorRequests.size()); + + // Persist monitors sequentially + for (IndexMonitorRequest req: monitorRequests) { + AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, req, namedWriteableRegistry, monitorResponseListener); + } + } + + private void updateMonitorFromQueries(String index, List> rulesById, Detector detector, ActionListener> listener, WriteRequest.RefreshPolicy refreshPolicy) throws SigmaError, IOException { + List monitorsToBeUpdated = new ArrayList<>(); + + List> bucketLevelRules = rulesById.stream().filter(it -> it.getRight().isAggregationRule()).collect( + Collectors.toList()); + List monitorsToBeAdded = new ArrayList<>(); + // Process bucket level monitors + if (!bucketLevelRules.isEmpty()) { + List ruleCategories = bucketLevelRules.stream().map(Pair::getRight).map(Rule::getCategory).distinct().collect( + Collectors.toList()); + Map queryBackendMap = new HashMap<>(); + for(String 
category: ruleCategories){ + queryBackendMap.put(category, new OSQueryBackend(category, true, true)); + } + + // Pair of RuleId - MonitorId for existing monitors of the detector + Map monitorPerRule = detector.getRuleIdMonitorId(); + + for (Pair query: bucketLevelRules) { + Rule rule = query.getRight(); + if(rule.getAggregationQueries() != null){ + // Detect if the monitor should be added or updated + if (monitorPerRule.containsKey(rule.getId())) { + String monitorId = monitorPerRule.get(rule.getId()); + monitorsToBeUpdated.add(createBucketLevelMonitorRequest(query.getRight(), + index, + detector, + refreshPolicy, + monitorId, + Method.PUT, + queryBackendMap.get(rule.getCategory()))); + } else { + monitorsToBeAdded.add(createBucketLevelMonitorRequest(query.getRight(), + index, + detector, + refreshPolicy, + Monitor.NO_ID, + Method.POST, + queryBackendMap.get(rule.getCategory()))); + } + } + } + } + + List bucketMonitorIdsToBeDeleted = detector.getRuleIdMonitorId().values().stream().collect(Collectors.toList()); + bucketMonitorIdsToBeDeleted.removeAll(monitorsToBeUpdated.stream().map(IndexMonitorRequest::getMonitorId).collect( + Collectors.toList())); + + List> docLevelRules = rulesById.stream().filter(it -> !it.getRight().isAggregationRule()).collect( + Collectors.toList()); + + // Process doc level monitors + if (!docLevelRules.isEmpty()) { + if (detector.getDocLevelMonitorId() == null) { + monitorsToBeAdded.add(createDocLevelMonitorRequest(Pair.of(index, docLevelRules), detector, refreshPolicy, Monitor.NO_ID, Method.POST)); + } else { + monitorsToBeUpdated.add(createDocLevelMonitorRequest(Pair.of(index, docLevelRules), detector, refreshPolicy, detector.getDocLevelMonitorId(), Method.PUT)); + } + } else { + if(detector.getDocLevelMonitorId() != null) { + bucketMonitorIdsToBeDeleted.add(detector.getDocLevelMonitorId()); + } + } + + updateAlertingMonitors(monitorsToBeAdded, monitorsToBeUpdated, bucketMonitorIdsToBeDeleted, refreshPolicy, listener); + } + + /** + * Updates the list of monitors for the given detector. + * Executed in steps: + * 1. Add new monitors; + * 2. Update existing monitors; + * 3. Delete the monitors omitted from the request; + * 4. Respond with the updated list of monitors + * @param monitorsToBeAdded Monitors newly added by the user + * @param monitorsToBeUpdated Existing monitors that will be updated + * @param monitorsToBeDeleted Monitors omitted by the user + * @param refreshPolicy Write request refresh policy + * @param listener Listener that accepts the list of updated monitors if the action was successful + */ + private void updateAlertingMonitors( + List monitorsToBeAdded, + List monitorsToBeUpdated, + List monitorsToBeDeleted, + RefreshPolicy refreshPolicy, + ActionListener> listener + ) { + List updatedMonitors = new ArrayList<>(); + + // Update monitor steps + StepListener> addNewMonitorsStep = new StepListener(); + executeMonitorActionRequest(monitorsToBeAdded, addNewMonitorsStep); + // 1. Add new alerting bucket monitors (for the rules that didn't exist previously) + addNewMonitorsStep.whenComplete(addNewMonitorsResponse -> { + updatedMonitors.addAll(addNewMonitorsResponse); + StepListener> updateMonitorsStep = new StepListener<>(); + executeMonitorActionRequest(monitorsToBeUpdated, updateMonitorsStep); + // 2.
Update existing bucket alerting monitors (based on the common rules) + updateMonitorsStep.whenComplete(updateMonitorResponse -> { + updatedMonitors.addAll(updateMonitorResponse); + StepListener> deleteMonitorStep = new StepListener<>(); + deleteAlertingMonitors(monitorsToBeDeleted, refreshPolicy, deleteMonitorStep); + // 3. Delete bucket alerting monitors (rules that are not provided by the user) + deleteMonitorStep.whenComplete(deleteMonitorResponses -> + // Return list of all updated + newly added monitors + listener.onResponse(updatedMonitors), + // Handle delete monitors (step 3) + listener::onFailure); + }, // Handle update monitor failed (step 2) + listener::onFailure); + // Handle add failed (step 1) + }, listener::onFailure); } - private void createAlertingMonitorFromQueries(Pair>> logIndexToQueries, Detector detector, ActionListener listener, WriteRequest.RefreshPolicy refreshPolicy) { + private IndexMonitorRequest createDocLevelMonitorRequest(Pair>> logIndexToQueries, Detector detector, WriteRequest.RefreshPolicy refreshPolicy, String monitorId, RestRequest.Method restMethod) { List docLevelMonitorInputs = new ArrayList<>(); List docLevelQueries = new ArrayList<>(); @@ -196,154 +336,174 @@ private void createAlertingMonitorFromQueries(Pair>> logIndexToQueries, Detector detector, ActionListener listener, WriteRequest.RefreshPolicy refreshPolicy) throws IOException, SigmaError { - // TODO - think about the smarter way - // Prepare the queryBackend instances per rule category + private List buildBucketLevelMonitorRequests(Pair>> logIndexToQueries, Detector detector, WriteRequest.RefreshPolicy refreshPolicy, String monitorId, RestRequest.Method restMethod) throws IOException, SigmaError { List ruleCategories = logIndexToQueries.getRight().stream().map(Pair::getRight).map(Rule::getCategory).distinct().collect( Collectors.toList()); Map queryBackendMap = new HashMap<>(); + for(String category: ruleCategories){ queryBackendMap.put(category, new OSQueryBackend(category, true, true)); } + List monitorRequests = new ArrayList<>(); + for (Pair query: logIndexToQueries.getRight()) { Rule rule = query.getRight(); + // Creating bucket level monitor per each aggregation rule - // TODO - check if bucket level monitors needs to be created per rule? 
if(rule.getAggregationQueries() != null){ - AggregationQueries aggregationQueries = queryBackendMap.get(rule.getCategory()).convertAggregation(rule.getAggregationItemsFromRule().get(0)); - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .seqNoAndPrimaryTerm(true) - .version(true) - // Build query string filter - .query(QueryBuilders.queryStringQuery(rule.getQueries().get(0).getValue())) - .aggregation(aggregationQueries.getAggBuilder()) - .size(10000); - - List bucketLevelMonitorInputs = new ArrayList<>(); - bucketLevelMonitorInputs.add(new SearchInput(Arrays.asList(logIndexToQueries.getKey()), searchSourceBuilder)); - - List detectorTriggers = detector.getTriggers(); - List triggers = new ArrayList<>(); - - for (DetectorTrigger detectorTrigger: detectorTriggers) { - String id = detectorTrigger.getId(); - String name = detectorTrigger.getName(); - String severity = detectorTrigger.getSeverity(); - List actions = detectorTrigger.getActions(); - BucketLevelTrigger bucketLevelTrigger = new BucketLevelTrigger(id, name, severity, aggregationQueries.getCondition(), actions); - triggers.add(bucketLevelTrigger); - } - - Monitor monitor = new Monitor(Monitor.NO_ID, Monitor.NO_VERSION, detector.getName(), detector.getEnabled(), detector.getSchedule(), detector.getLastUpdateTime(), detector.getEnabledTime(), - MonitorType.BUCKET_LEVEL_MONITOR, detector.getUser(), 1, bucketLevelMonitorInputs, triggers, Map.of(), new DataSources()); - - // TODO - remove after figuring out how to serde monitor request - testSerde(refreshPolicy, monitor); - - IndexMonitorRequest indexMonitorRequest = new IndexMonitorRequest(Monitor.NO_ID, SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM, refreshPolicy, RestRequest.Method.POST, monitor); - AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, indexMonitorRequest, listener); + monitorRequests.add(createBucketLevelMonitorRequest( + query.getRight(), + logIndexToQueries.getLeft(), + detector, + refreshPolicy, + Monitor.NO_ID, + Method.POST, + queryBackendMap.get(rule.getCategory()))); } } + return monitorRequests; } - // TODO - delete the method after figuring out - private static void testSerde( - RefreshPolicy refreshPolicy, - Monitor monitor - ) throws IOException { - XContentBuilder builder = monitor.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - String monitorAsString = BytesReference.bytes(builder).utf8ToString(); - - final SearchModule searchModule = new SearchModule(Settings.EMPTY, emptyList()); - final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(searchModule.getNamedWriteables()); - - BytesStreamOutput out = new BytesStreamOutput(); - monitor.writeTo(out); - Monitor newReq; - try (StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), namedWriteableRegistry)) { - newReq = new Monitor(in); - } - log.info(newReq); - - IndexMonitorRequest indexMonitorRequest = new IndexMonitorRequest(Monitor.NO_ID, SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM, - refreshPolicy, RestRequest.Method.POST, - monitor); - BytesStreamOutput out1 = new BytesStreamOutput(); - indexMonitorRequest.writeTo(out1); - IndexMonitorRequest req; - try (StreamInput in = new NamedWriteableAwareStreamInput(out1.bytes().streamInput(), namedWriteableRegistry)) { - req = new IndexMonitorRequest(in); - } - - log.info(req); - log.info(monitorAsString); + private IndexMonitorRequest createBucketLevelMonitorRequest( + Rule rule, + String index, + 
Detector detector, + WriteRequest.RefreshPolicy refreshPolicy, + String monitorId, + RestRequest.Method restMethod, + QueryBackend queryBackend + ) throws SigmaError { + AggregationQueries aggregationQueries = queryBackend.convertAggregation(rule.getAggregationItemsFromRule().get(0)); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .seqNoAndPrimaryTerm(true) + .version(true) + // Build query string filter + .query(QueryBuilders.queryStringQuery(rule.getQueries().get(0).getValue())) + .aggregation(aggregationQueries.getAggBuilder()) + .size(10000); + + List bucketLevelMonitorInputs = new ArrayList<>(); + bucketLevelMonitorInputs.add(new SearchInput(Arrays.asList(index), searchSourceBuilder)); + + List triggers = new ArrayList<>(); + BucketLevelTrigger bucketLevelTrigger = new BucketLevelTrigger(rule.getId(), rule.getTitle(), rule.getLevel(), aggregationQueries.getCondition(), + Collections.emptyList()); + triggers.add(bucketLevelTrigger); + + /** TODO - Think how to use detector trigger + List detectorTriggers = detector.getTriggers(); + for (DetectorTrigger detectorTrigger: detectorTriggers) { + String id = detectorTrigger.getId(); + String name = detectorTrigger.getName(); + String severity = detectorTrigger.getSeverity(); + List actions = detectorTrigger.getActions(); + Script condition = detectorTrigger.convertToCondition(); + + BucketLevelTrigger bucketLevelTrigger1 = new BucketLevelTrigger(id, name, severity, condition, actions); + triggers.add(bucketLevelTrigger1); + } **/ + + Monitor monitor = new Monitor(monitorId, Monitor.NO_VERSION, detector.getName(), detector.getEnabled(), detector.getSchedule(), detector.getLastUpdateTime(), detector.getEnabledTime(), + MonitorType.BUCKET_LEVEL_MONITOR, detector.getUser(), 1, bucketLevelMonitorInputs, triggers, Map.of(), + new DataSources(detector.getRuleIndex(), + detector.getFindingsIndex(), + detector.getFindingsIndexPattern(), + detector.getAlertsIndex(), + detector.getAlertsHistoryIndex(), + detector.getAlertsHistoryIndexPattern(), + DetectorMonitorConfig.getRuleIndexMappingsByType(detector.getDetectorType()))); + + return new IndexMonitorRequest(monitorId, SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM, refreshPolicy, restMethod, monitor); } - private void updateAlertingMonitorFromQueries(Pair>> logIndexToQueries, Detector detector, ActionListener listener, WriteRequest.RefreshPolicy refreshPolicy) { - List docLevelMonitorInputs = new ArrayList<>(); - - List docLevelQueries = new ArrayList<>(); - - for (Pair query: logIndexToQueries.getRight()) { - String id = query.getLeft(); - - Rule rule = query.getRight(); - String name = query.getLeft(); - - String actualQuery = rule.getQueries().get(0).getValue(); - - List tags = new ArrayList<>(); - tags.add(rule.getLevel()); - tags.add(rule.getCategory()); - tags.addAll(rule.getTags().stream().map(Value::getValue).collect(Collectors.toList())); - - DocLevelQuery docLevelQuery = new DocLevelQuery(id, name, actualQuery, tags); - docLevelQueries.add(docLevelQuery); + /** + * Executes monitor related requests (PUT/POST) - returns the response once all the executions are completed + * @param indexMonitors Monitors to be updated/added + * @param listener actionListener for handling updating/creating monitors + */ + public void executeMonitorActionRequest( + List indexMonitors, + ActionListener> listener) { + + // In the case of not provided monitors, just return empty list + if(indexMonitors == null || indexMonitors.size() == 0) { + 
listener.onResponse(new ArrayList<>()); + return; } - DocLevelMonitorInput docLevelMonitorInput = new DocLevelMonitorInput(detector.getName(), List.of(logIndexToQueries.getKey()), docLevelQueries); - docLevelMonitorInputs.add(docLevelMonitorInput); - List triggers = new ArrayList<>(); - List detectorTriggers = detector.getTriggers(); + GroupedActionListener monitorResponseListener = new GroupedActionListener( + new ActionListener>() { + @Override + public void onResponse(Collection indexMonitorResponse) { + listener.onResponse(indexMonitorResponse.stream().collect(Collectors.toList())); + } + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }, indexMonitors.size()); - for (DetectorTrigger detectorTrigger: detectorTriggers) { - String id = detectorTrigger.getId(); - String name = detectorTrigger.getName(); - String severity = detectorTrigger.getSeverity(); - List actions = detectorTrigger.getActions(); - Script condition = detectorTrigger.convertToCondition(); + // Persist monitors sequentially + for (IndexMonitorRequest req: indexMonitors) { + AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, req, namedWriteableRegistry, monitorResponseListener); + } - triggers.add(new DocumentLevelTrigger(id, name, severity, actions, condition)); + /** + * Deletes the alerting monitors with the given ids and notifies the listener once all deletions have completed + * @param monitorIds Ids of the monitors to be deleted + * @param refreshPolicy Write request refresh policy + * @param listener Listener notified once all the monitors have been deleted + */ + private void deleteAlertingMonitors(List monitorIds, WriteRequest.RefreshPolicy refreshPolicy, ActionListener> listener){ + if (monitorIds == null || monitorIds.isEmpty()) { + listener.onResponse(new ArrayList<>()); + return; + } + ActionListener deletesListener = new GroupedActionListener<>(new ActionListener<>() { + @Override + public void onResponse(Collection responses) { + SetOnce errorStatusSupplier = new SetOnce<>(); + if (responses.stream().filter(response -> { + if (response.getStatus() != RestStatus.OK) { + log.error("Monitor [{}] could not be deleted.
Status [{}]", response.getId(), response.getStatus()); + errorStatusSupplier.trySet(response.getStatus()); + return true; + } + return false; + }).count() > 0) { + listener.onFailure(new OpenSearchStatusException("Monitor associated with detected could not be deleted", errorStatusSupplier.get())); + } + listener.onResponse(responses.stream().collect(Collectors.toList())); + } + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }, monitorIds.size()); - Monitor monitor = new Monitor(detector.getMonitorIds().get(0), Monitor.NO_VERSION, detector.getName(), detector.getEnabled(), detector.getSchedule(), detector.getLastUpdateTime(), detector.getEnabledTime(), - Monitor.MonitorType.DOC_LEVEL_MONITOR, detector.getUser(), 1, docLevelMonitorInputs, triggers, Map.of(), - new DataSources(detector.getRuleIndex(), - detector.getFindingsIndex(), - detector.getFindingsIndexPattern(), - detector.getAlertsIndex(), - detector.getAlertsHistoryIndex(), - detector.getAlertsHistoryIndexPattern(), - DetectorMonitorConfig.getRuleIndexMappingsByType(detector.getDetectorType()))); - - IndexMonitorRequest indexMonitorRequest = new IndexMonitorRequest(detector.getMonitorIds().get(0), SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM, refreshPolicy, RestRequest.Method.PUT, monitor); - AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, indexMonitorRequest, listener); + for (String monitorId : monitorIds) { + deleteAlertingMonitor(monitorId, refreshPolicy, deletesListener); + } + } + private void deleteAlertingMonitor(String monitorId, WriteRequest.RefreshPolicy refreshPolicy, ActionListener listener) { + DeleteMonitorRequest request = new DeleteMonitorRequest(monitorId, refreshPolicy); + AlertingPluginInterface.INSTANCE.deleteMonitor((NodeClient) client, request, listener); } private void onCreateMappingsResponse(CreateIndexResponse response) throws IOException { @@ -458,8 +618,9 @@ public void onResponse(CreateIndexResponse createIndexResponse) { initRuleIndexAndImportRules(request, new ActionListener<>() { @Override - public void onResponse(IndexMonitorResponse indexMonitorResponse) { - request.getDetector().setMonitorIds(List.of(indexMonitorResponse.getId())); + public void onResponse(List monitorResponses) { + request.getDetector().setMonitorIds(getMonitorIds(monitorResponses)); + request.getDetector().setRuleIdMonitorId(mapMonitorIds(monitorResponses)); try { indexDetector(); } catch (IOException e) { @@ -522,6 +683,7 @@ void onGetResponse(Detector currentDetector) { request.getDetector().setEnabledTime(currentDetector.getEnabledTime()); } request.getDetector().setMonitorIds(currentDetector.getMonitorIds()); + request.getDetector().setRuleIdMonitorId(currentDetector.getRuleIdMonitorId()); Detector detector = request.getDetector(); String ruleTopic = detector.getDetectorType(); @@ -540,8 +702,9 @@ void onGetResponse(Detector currentDetector) { public void onResponse(CreateIndexResponse createIndexResponse) { initRuleIndexAndImportRules(request, new ActionListener<>() { @Override - public void onResponse(IndexMonitorResponse indexMonitorResponse) { - request.getDetector().setMonitorIds(List.of(indexMonitorResponse.getId())); + public void onResponse(List monitorResponses) { + request.getDetector().setMonitorIds(getMonitorIds(monitorResponses)); + request.getDetector().setRuleIdMonitorId(mapMonitorIds(monitorResponses)); try { indexDetector(); } catch (IOException e) { @@ -567,7 +730,7 @@ public void onFailure(Exception e) { } } - public void 
initRuleIndexAndImportRules(IndexDetectorRequest request, ActionListener listener) { + public void initRuleIndexAndImportRules(IndexDetectorRequest request, ActionListener> listener) { ruleIndices.initPrepackagedRulesIndex( new ActionListener<>() { @Override @@ -672,7 +835,7 @@ public void onFailure(Exception e) { } @SuppressWarnings("unchecked") - public void importRules(IndexDetectorRequest request, ActionListener listener) { + public void importRules(IndexDetectorRequest request, ActionListener> listener) { final Detector detector = request.getDetector(); final String ruleTopic = detector.getDetectorType(); final DetectorInput detectorInput = detector.getInputs().get(0); @@ -725,12 +888,10 @@ public void onResponse(SearchResponse response) { } else if (detectorInput.getCustomRules().size() > 0) { onFailures(new OpenSearchStatusException("Custom Rule Index not found", RestStatus.BAD_REQUEST)); } else { - Pair>> logIndexToQueries = Pair.of(logIndex, queries); - if (request.getMethod() == RestRequest.Method.POST) { createMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); } else if (request.getMethod() == RestRequest.Method.PUT) { - updateAlertingMonitorFromQueries(logIndexToQueries, detector, listener, request.getRefreshPolicy()); + updateMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); } } } catch (IOException | SigmaError e) { @@ -746,7 +907,7 @@ public void onFailure(Exception e) { } @SuppressWarnings("unchecked") - public void importCustomRules(Detector detector, DetectorInput detectorInput, List> queries, ActionListener listener) { + public void importCustomRules(Detector detector, DetectorInput detectorInput, List> queries, ActionListener> listener) { final String logIndex = detectorInput.getIndices().get(0); List ruleIds = detectorInput.getCustomRules().stream().map(DetectorRule::getId).collect(Collectors.toList()); @@ -780,12 +941,10 @@ public void onResponse(SearchResponse response) { queries.add(Pair.of(id, rule)); } - Pair>> logIndexToQueries = Pair.of(logIndex, queries); - if (request.getMethod() == RestRequest.Method.POST) { createMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); } else if (request.getMethod() == RestRequest.Method.PUT) { - updateAlertingMonitorFromQueries(logIndexToQueries, detector, listener, request.getRefreshPolicy()); + updateMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); } } catch (IOException | SigmaError ex) { onFailures(ex); @@ -851,5 +1010,32 @@ private void finishHim(Detector detector, Exception t) { } })); } + + private List getMonitorIds(List monitorResponses) { + return monitorResponses.stream().map(IndexMonitorResponse::getId).collect( + Collectors.toList()); + } + + /** + * Creates a map of monitor ids. 
In the case of bucket level monitors pairs are: RuleId - MonitorId + * In the case of doc level monitor pair is DOC_LEVEL_MONITOR(value) - MonitorId + * @param monitorResponses index monitor responses + * @return map of monitor ids + */ + private Map mapMonitorIds(List monitorResponses) { + return monitorResponses.stream().collect( + Collectors.toMap( + // In the case of bucket level monitors rule id is trigger id + it -> { + if (MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType()) { + return it.getMonitor().getTriggers().get(0).getId(); + } else { + return Detector.DOC_LEVEL_MONITOR; + } + }, + IndexMonitorResponse::getId + ) + ); + } } } \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java b/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java index a3511b862..84c15c12e 100644 --- a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java +++ b/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java @@ -4,6 +4,7 @@ */ package org.opensearch.securityanalytics; +import java.util.ArrayList; import org.apache.http.Header; import org.apache.http.HttpEntity; import org.apache.http.entity.ContentType; @@ -52,6 +53,8 @@ import java.util.stream.Collectors; import static org.opensearch.action.admin.indices.create.CreateIndexRequest.MAPPINGS; +import static org.opensearch.securityanalytics.TestHelpers.sumAggregationTestRule; +import static org.opensearch.securityanalytics.TestHelpers.productIndexAvgAggRule; public class SecurityAnalyticsRestTestCase extends OpenSearchRestTestCase { @@ -186,6 +189,18 @@ protected List getRandomPrePackagedRules() throws IOException { return hits.stream().map(hit -> hit.get("_id").toString()).collect(Collectors.toList()); } + protected List createAggregationRules () throws IOException { + return new ArrayList<>(Arrays.asList(createRule(productIndexAvgAggRule()), createRule(sumAggregationTestRule()))); + } + + protected String createRule(String rule) throws IOException { + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "windows"), + new StringEntity(rule), new BasicHeader("Content-Type", "application/json")); + Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + return responseBody.get("_id").toString(); + } + protected List getPrePackagedRules(String ruleCategory) throws IOException { String request = "{\n" + " \"from\": 0\n," + diff --git a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java index 6b863f710..af5ed9a82 100644 --- a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java +++ b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java @@ -37,7 +37,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.function.Function; import java.util.stream.Collectors; import static org.opensearch.test.OpenSearchTestCase.randomInt; @@ -121,7 +120,7 @@ public static Detector randomDetector(String name, DetectorTrigger trigger = new DetectorTrigger(null, "windows-trigger", "1", List.of("windows"), List.of("QuarksPwDump Clearing Access History"), List.of("high"), List.of("T0008"), List.of()); triggers.add(trigger); } - return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, 
detectorType, user, inputs, triggers, Collections.singletonList(""), "", "", "", "", "", ""); + return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, user, inputs, triggers, Collections.singletonList(""), "", "", "", "", "", "", Collections.emptyMap()); } public static Detector randomDetectorWithNoUser() { @@ -133,7 +132,7 @@ public static Detector randomDetectorWithNoUser() { Instant enabledTime = enabled ? Instant.now().truncatedTo(ChronoUnit.MILLIS) : null; Instant lastUpdateTime = Instant.now().truncatedTo(ChronoUnit.MILLIS); - return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, null, inputs, Collections.emptyList(),Collections.singletonList(""), "", "", "", "", "", ""); + return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, null, inputs, Collections.emptyList(),Collections.singletonList(""), "", "", "", "", "", "", Collections.emptyMap()); } public static String randomRule() { @@ -164,6 +163,7 @@ public static String randomRule() { " - Legitimate usage of remote file encryption\n" + "level: high"; } + public static String countAggregationTestRule() { return " title: Test\n" + " id: 39f919f3-980b-4e6f-a975-8af7e507ef2b\n" + @@ -183,7 +183,7 @@ public static String countAggregationTestRule() { " condition: sel | count(*) > 1"; } - public static String avgAggregationTestRule() { + public static String sumAggregationTestRule() { return " title: Test\n" + " id: 39f919f3-980b-4e6f-a975-8af7e507ef2b\n" + " status: test\n" + @@ -196,10 +196,37 @@ public static String avgAggregationTestRule() { " product: test_product\n" + " detection:\n" + " sel:\n" + - " fieldA: valueA\n" + - " fieldB: valueB\n" + + " fieldA: 123\n" + + " fieldB: 111\n" + + " fieldC: valueC\n" + + " condition: sel | sum(fieldA) by fieldB > 110"; + } + + public static String productIndexMaxAggRule() { + return " title: Test\n" + + " id: 5f92fff9-82e3-48eb-8fc1-8b133556a551\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " sel:\n" + + " fieldA: 123\n" + + " fieldB: 111\n" + " fieldC: valueC\n" + - " condition: sel | avg(fieldA) by fieldB > 110"; + " condition: sel | max(fieldA) by fieldB > 110"; + } + + public static String randomProductDocument(){ + return "{\n" + + " \"fieldA\": 123,\n" + + " \"mappedB\": 111,\n" + + " \"fieldC\": \"valueC\"\n" + + "}\n"; } public static String randomEditedRule() { @@ -394,6 +421,40 @@ public static String netFlowMappings() { " }"; } + public static String productIndexMapping(){ + return "\"properties\":{\n" + + " \"fieldA\":{\n" + + " \"type\":\"long\"\n" + + " },\n" + + " \"mappedB\":{\n" + + " \"type\":\"long\"\n" + + " },\n" + + " \"fieldC\":{\n" + + " \"type\":\"keyword\"\n" + + " }\n" + + "}\n" + + "}"; + } + + public static String productIndexAvgAggRule(){ + return " title: Test\n" + + " id: 39f918f3-981b-4e6f-a975-8af7e507ef2b\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " sel:\n" + + " fieldA: 123\n" + + " fieldB: 111\n" + + " fieldC: valueC\n" + + " condition: sel | 
avg(fieldA) by fieldC > 110"; + } + public static String windowsIndexMapping() { return "\"properties\": {\n" + " \"AccessList\": {\n" + diff --git a/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java b/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java index 84f930d1b..ad6a110e2 100644 --- a/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java +++ b/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java @@ -49,7 +49,8 @@ public void testIndexDetectorPostResponse() throws IOException { DetectorMonitorConfig.getAlertsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), null, null, - DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()) + DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), + Collections.emptyMap() ); IndexDetectorResponse response = new IndexDetectorResponse("1234", 1L, RestStatus.OK, detector); Assert.assertNotNull(response); diff --git a/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java b/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java index 4a061525a..29dc741ea 100644 --- a/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java +++ b/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java @@ -7,6 +7,7 @@ import java.time.Instant; import java.time.ZoneId; +import java.util.Collections; import java.util.List; import java.util.Map; import org.opensearch.action.ActionListener; @@ -61,7 +62,8 @@ public void testGetAlerts_success() { DetectorMonitorConfig.getAlertsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), null, null, - DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()) + DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), + Collections.emptyMap() ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); @@ -230,7 +232,8 @@ public void testGetFindings_getFindingsByMonitorIdFailures() { DetectorMonitorConfig.getAlertsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), null, null, - DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()) + DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), + Collections.emptyMap() ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); diff --git a/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java b/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java index c5c0cb425..6ad0b5a14 100644 --- a/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java +++ b/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java @@ -8,8 +8,10 @@ import java.time.Instant; import java.time.ZoneId; import java.util.ArrayDeque; +import java.util.Collections; import java.util.List; import java.util.Queue; +import java.util.stream.Collectors; import org.opensearch.action.ActionListener; import org.opensearch.client.Client; import org.opensearch.commons.alerting.model.CronSchedule; @@ -61,7 +63,8 @@ public void testGetFindings_success() { 
DetectorMonitorConfig.getAlertsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), null, null, - DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()) + DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), + Collections.emptyMap() ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); @@ -176,7 +179,8 @@ public void testGetFindings_getFindingsByMonitorIdFailure() { DetectorMonitorConfig.getAlertsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), null, null, - DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()) + DetectorMonitorConfig.getFindingsIndex(Detector.DetectorType.OTHERS_APPLICATION.getDetectorType()), + Collections.emptyMap() ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java index 668436146..50dca7cb2 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java @@ -4,6 +4,9 @@ */ package org.opensearch.securityanalytics.resthandler; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; import org.apache.http.HttpEntity; import org.apache.http.HttpStatus; import org.apache.http.entity.ContentType; @@ -14,6 +17,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Request; import org.opensearch.client.Response; +import org.opensearch.commons.alerting.model.Monitor.MonitorType; import org.opensearch.rest.RestStatus; import org.opensearch.search.SearchHit; import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; @@ -29,13 +33,18 @@ import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; +import org.opensearch.securityanalytics.model.Rule; -import static org.opensearch.securityanalytics.TestHelpers.avgAggregationTestRule; +import static org.opensearch.securityanalytics.TestHelpers.productIndexMaxAggRule; +import static org.opensearch.securityanalytics.TestHelpers.productIndexAvgAggRule; +import static org.opensearch.securityanalytics.TestHelpers.productIndexMapping; import static org.opensearch.securityanalytics.TestHelpers.randomDetector; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputs; import static org.opensearch.securityanalytics.TestHelpers.randomDoc; import static org.opensearch.securityanalytics.TestHelpers.randomIndex; +import static org.opensearch.securityanalytics.TestHelpers.randomProductDocument; import static org.opensearch.securityanalytics.TestHelpers.randomRule; +import static org.opensearch.securityanalytics.TestHelpers.sumAggregationTestRule; import static org.opensearch.securityanalytics.TestHelpers.windowsIndexMapping; public class DetectorRestApiIT extends SecurityAnalyticsRestTestCase { @@ -159,7 +168,7 @@ public void testSearchingDetectors() throws IOException { Map searchResponseTotal = (Map) searchResponseHits.get("total"); Assert.assertEquals(1, searchResponseTotal.get("value")); } - + @SuppressWarnings("unchecked") public void testCreatingADetectorWithCustomRules() throws IOException { String index = createTestIndex(randomIndex(), 
windowsIndexMapping()); @@ -226,7 +235,7 @@ public void testCreatingADetectorWithCustomRules() throws IOException { } public void testCreatingADetectorWithAggregationRules() throws IOException { - String index = createTestIndex(randomIndex(), windowsIndexMapping()); + String index = createTestIndex(randomIndex(), productIndexMapping()); // Execute CreateMappingsAction to add alias mapping for index Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); @@ -241,7 +250,7 @@ public void testCreatingADetectorWithAggregationRules() throws IOException { Response response = client().performRequest(createMappingRequest); assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); - String rule = avgAggregationTestRule(); + String rule = productIndexAvgAggRule(); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "windows"), new StringEntity(rule), new BasicHeader("Content-Type", "application/json")); @@ -249,9 +258,9 @@ public void testCreatingADetectorWithAggregationRules() throws IOException { Map responseBody = asMap(createResponse); - String createdId = responseBody.get("_id").toString(); + String createdRuleId = responseBody.get("_id").toString(); - DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(createdId)), + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(createdRuleId)), getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList())); Detector detector = randomDetectorWithInputs(List.of(input)); @@ -260,11 +269,11 @@ public void testCreatingADetectorWithAggregationRules() throws IOException { responseBody = asMap(createResponse); - createdId = responseBody.get("_id").toString(); + createdRuleId = responseBody.get("_id").toString(); int createdVersion = Integer.parseInt(responseBody.get("_version").toString()); - Assert.assertNotEquals("response is missing Id", Detector.NO_ID, createdId); + Assert.assertNotEquals("response is missing Id", Detector.NO_ID, createdRuleId); Assert.assertTrue("incorrect version", createdVersion > 0); - Assert.assertEquals("Incorrect Location header", String.format(Locale.getDefault(), "%s/%s", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, createdId), createResponse.getHeader("Location")); + Assert.assertEquals("Incorrect Location header", String.format(Locale.getDefault(), "%s/%s", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, createdRuleId), createResponse.getHeader("Location")); Assert.assertFalse(((Map) responseBody.get("detector")).containsKey("rule_topic_index")); Assert.assertFalse(((Map) responseBody.get("detector")).containsKey("findings_index")); Assert.assertFalse(((Map) responseBody.get("detector")).containsKey("alert_index")); @@ -272,24 +281,46 @@ public void testCreatingADetectorWithAggregationRules() throws IOException { String request = "{\n" + " \"query\" : {\n" + " \"match\":{\n" + - " \"_id\": \"" + createdId + "\"\n" + + " \"_id\": \"" + createdRuleId + "\"\n" + " }\n" + " }\n" + "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); - String monitorId = ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); + List monitorTypes = new ArrayList<>(); - indexDoc(index, "1", randomDoc()); + Map detectorAsMap = (Map) hit.getSourceAsMap().get("detector"); - Response 
executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); - Map executeResults = entityAsMap(executeResponse); + String bucketLevelMonitorId = ""; - int noOfSigmaRuleMatches = ((List<Map<String, Object>>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); - Assert.assertEquals(6, noOfSigmaRuleMatches); - } + // Verify that both a doc level and a bucket level monitor are created + List<String> monitorIds = (List<String>) (detectorAsMap).get("monitor_id"); + + String firstMonitorId = monitorIds.get(0); + String firstMonitorType = ((Map<String, String>) entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + firstMonitorId))).get("monitor")).get("monitor_type"); + + if (MonitorType.BUCKET_LEVEL_MONITOR.getValue().equals(firstMonitorType)) { + bucketLevelMonitorId = firstMonitorId; + } + monitorTypes.add(firstMonitorType); + + String secondMonitorId = monitorIds.get(1); + String secondMonitorType = ((Map<String, String>) entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + secondMonitorId))).get("monitor")).get("monitor_type"); + monitorTypes.add(secondMonitorType); + if (MonitorType.BUCKET_LEVEL_MONITOR.getValue().equals(secondMonitorType)) { + bucketLevelMonitorId = secondMonitorId; + } + Assert.assertTrue(Arrays.asList(MonitorType.BUCKET_LEVEL_MONITOR.getValue(), MonitorType.DOC_LEVEL_MONITOR.getValue()).containsAll(monitorTypes)); + + indexDoc(index, "1", randomProductDocument()); + Response executeResponse = executeAlertingMonitor(bucketLevelMonitorId, Collections.emptyMap()); + Map executeResults = entityAsMap(executeResponse); + // Confirm that the bucket level monitor execution returns the indexed product document as a hit + int numOfHits = ((ArrayList)((Map) ((List<Map<String, Object>>) ((Map) executeResults.get("input_results")).get("results")).get(0).get("hits")).get("hits")).size(); + Assert.assertEquals(1, numOfHits); + } public void testUpdateADetector() throws IOException { String index = createTestIndex(randomIndex(), windowsIndexMapping()); @@ -351,6 +382,186 @@ public void testUpdateADetector() throws IOException { Assert.assertEquals(1580, response.getHits().getTotalHits().value); } + public void testUpdateDetectorAddingNewAggregationRule() throws IOException { + String index = createTestIndex(randomIndex(), productIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"windows\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String sumRuleId = createRule(sumAggregationTestRule()); + List detectorRules = List.of(new DetectorRule(sumRuleId)); + + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + Collections.emptyList()); + + Detector detector = randomDetectorWithInputs(List.of(input)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " 
}\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(Rule.CUSTOM_RULES_INDEX, request, true); + Assert.assertEquals(1, response.getHits().getTotalHits().value); + + // Test adding the new max monitor and updating the existing sum monitor + String maxRuleId = createRule(productIndexMaxAggRule()); + DetectorInput newInput = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(maxRuleId), new DetectorRule(sumRuleId)), + Collections.emptyList()); + Detector firstUpdatedDetector = randomDetectorWithInputs(List.of(newInput)); + Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(firstUpdatedDetector)); + Assert.assertEquals("Update detector failed", RestStatus.OK, restStatus(updateResponse)); + request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + List<SearchHit> hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map firstUpdateDetectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + List inputArr = (List) firstUpdateDetectorMap.get("inputs"); + Assert.assertEquals(2, ((Map<String, Map<String, List<String>>>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + } + + public void testUpdateDetectorDeletingExistingAggregationRule() throws IOException { + String index = createTestIndex(randomIndex(), productIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"windows\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + List<String> aggRuleIds = createAggregationRules(); + List detectorRules = aggRuleIds.stream().map(DetectorRule::new).collect(Collectors.toList()); + + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + Collections.emptyList()); + + Detector detector = randomDetectorWithInputs(List.of(input)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(Rule.CUSTOM_RULES_INDEX, request, true); + Assert.assertEquals(2, response.getHits().getTotalHits().value); + + // Test deleting one of the existing aggregation rules from the detector + DetectorInput newInput = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(aggRuleIds.get(0))), + Collections.emptyList()); + Detector firstUpdatedDetector = randomDetectorWithInputs(List.of(newInput)); + Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(firstUpdatedDetector)); + Assert.assertEquals("Update detector failed", RestStatus.OK, 
restStatus(updateResponse)); + request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + List<SearchHit> hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map firstUpdateDetectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + List inputArr = (List) firstUpdateDetectorMap.get("inputs"); + Assert.assertEquals(1, ((Map<String, Map<String, List<String>>>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + } + + public void testUpdateDetectorWithAggregationAndDocLevelRules() throws IOException { + String index = createTestIndex(randomIndex(), productIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"windows\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + List<String> aggRuleIds = createAggregationRules(); + List detectorRules = aggRuleIds.stream().map(DetectorRule::new).collect(Collectors.toList()); + + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList())); + + Detector detector = randomDetectorWithInputs(List.of(input)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(Rule.CUSTOM_RULES_INDEX, request, true); + Assert.assertEquals(2, response.getHits().getTotalHits().value); + + String maxRuleId = createRule(productIndexMaxAggRule()); + + DetectorInput newInput = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(aggRuleIds.get(0)), new DetectorRule(maxRuleId)), + Collections.emptyList()); + + detector = randomDetectorWithInputs(List.of(newInput)); + createResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(detector)); + Assert.assertEquals("Update detector failed", RestStatus.OK, restStatus(createResponse)); + request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + List<SearchHit> hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map firstUpdateDetectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + List inputArr = (List) firstUpdateDetectorMap.get("inputs"); + Assert.assertEquals(2, ((Map<String, Map<String, List<String>>>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + } + @SuppressWarnings("unchecked") public void testDeletingADetector() throws IOException { String index = createTestIndex(randomIndex(), windowsIndexMapping());