Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creates bucket level monitors for rules containing aggregations #92

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
*/
package org.opensearch.securityanalytics.model;

import java.util.HashMap;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.ParseField;
Expand Down Expand Up @@ -49,6 +51,8 @@ public class Detector implements Writeable, ToXContentObject {
public static final String LAST_UPDATE_TIME_FIELD = "last_update_time";
public static final String ENABLED_TIME_FIELD = "enabled_time";
public static final String ALERTING_MONITOR_ID = "monitor_id";

public static final String BUCKET_MONITOR_ID_RULE_ID = "bucket_monitor_id_rule_id";
private static final String RULE_TOPIC_INDEX = "rule_topic_index";

private static final String ALERTS_INDEX = "alert_index";
Expand All @@ -59,6 +63,9 @@ public class Detector implements Writeable, ToXContentObject {

public static final String DETECTORS_INDEX = ".opensearch-sap-detectors-config";

// Used as a key in rule-monitor map for the purpose of easy detection of the doc level monitor
public static final String DOC_LEVEL_MONITOR = "-1";

public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
Detector.class,
new ParseField(DETECTOR_TYPE),
Expand Down Expand Up @@ -90,6 +97,8 @@ public class Detector implements Writeable, ToXContentObject {

private List<String> monitorIds;

private Map<String, String> ruleIdMonitorIdMap;

private String ruleIndex;

private String alertsIndex;
Expand All @@ -108,7 +117,7 @@ public Detector(String id, Long version, String name, Boolean enabled, Schedule
Instant lastUpdateTime, Instant enabledTime, DetectorType detectorType,
User user, List<DetectorInput> inputs, List<DetectorTrigger> triggers, List<String> monitorIds,
String ruleIndex, String alertsIndex, String alertsHistoryIndex, String alertsHistoryIndexPattern,
String findingsIndex, String findingsIndexPattern) {
String findingsIndex, String findingsIndexPattern, Map<String, String> rulePerMonitor) {
this.type = DETECTOR_TYPE;

this.id = id != null ? id : NO_ID;
Expand All @@ -129,6 +138,7 @@ public Detector(String id, Long version, String name, Boolean enabled, Schedule
this.alertsHistoryIndexPattern = alertsHistoryIndexPattern;
this.findingsIndex = findingsIndex;
this.findingsIndexPattern = findingsIndexPattern;
this.ruleIdMonitorIdMap = rulePerMonitor;

if (enabled) {
Objects.requireNonNull(enabledTime);
Expand All @@ -154,7 +164,9 @@ public Detector(StreamInput sin) throws IOException {
sin.readString(),
sin.readString(),
sin.readString(),
sin.readString());
sin.readString(),
sin.readMap(StreamInput::readString, StreamInput::readString)
);
}

@Override
Expand Down Expand Up @@ -186,6 +198,8 @@ public void writeTo(StreamOutput out) throws IOException {
}
out.writeStringCollection(monitorIds);
out.writeString(ruleIndex);

out.writeMap(ruleIdMonitorIdMap, StreamOutput::writeString, StreamOutput::writeString);
}

public XContentBuilder toXContentWithUser(XContentBuilder builder, Params params) throws IOException {
Expand Down Expand Up @@ -269,6 +283,7 @@ private XContentBuilder createXContentBuilder(XContentBuilder builder, ToXConten
}

builder.field(ALERTING_MONITOR_ID, monitorIds);
builder.field(BUCKET_MONITOR_ID_RULE_ID, ruleIdMonitorIdMap);
builder.field(RULE_TOPIC_INDEX, ruleIndex);
builder.field(ALERTS_INDEX, alertsIndex);
builder.field(ALERTS_HISTORY_INDEX, alertsHistoryIndex);
Expand Down Expand Up @@ -313,6 +328,8 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws
List<DetectorInput> inputs = new ArrayList<>();
List<DetectorTrigger> triggers = new ArrayList<>();
List<String> monitorIds = new ArrayList<>();
Map<String, String> rulePerMonitor = new HashMap<>();

String ruleIndex = null;
String alertsIndex = null;
String alertsHistoryIndex = null;
Expand Down Expand Up @@ -391,6 +408,9 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws
monitorIds.add(monitorId);
}
break;
case BUCKET_MONITOR_ID_RULE_ID:
rulePerMonitor= xcp.mapStrings();
break;
case RULE_TOPIC_INDEX:
ruleIndex = xcp.text();
break;
Expand Down Expand Up @@ -438,7 +458,8 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws
alertsHistoryIndex,
alertsHistoryIndexPattern,
findingsIndex,
findingsIndexPattern);
findingsIndexPattern,
rulePerMonitor);
}

public static Detector readFrom(StreamInput sin) throws IOException {
Expand Down Expand Up @@ -521,6 +542,8 @@ public void setUser(User user) {
this.user = user;
}

public Map<String, String> getRuleIdMonitorIdMap() {return ruleIdMonitorIdMap; }

public void setId(String id) {
this.id = id;
}
Expand Down Expand Up @@ -568,6 +591,13 @@ public void setInputs(List<DetectorInput> inputs) {
public void setMonitorIds(List<String> monitorIds) {
this.monitorIds = monitorIds;
}
public void setRuleIdMonitorIdMap(Map<String, String> ruleIdMonitorIdMap) {
this.ruleIdMonitorIdMap = ruleIdMonitorIdMap;
}

public String getDocLevelMonitorId() {
return ruleIdMonitorIdMap.get(DOC_LEVEL_MONITOR);
}

@Override
public boolean equals(Object o) {
Expand Down
55 changes: 49 additions & 6 deletions src/main/java/org/opensearch/securityanalytics/model/Rule.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package org.opensearch.securityanalytics.model;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.ParseField;
Expand All @@ -16,6 +17,11 @@
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.common.xcontent.XContentParserUtils;
import org.opensearch.securityanalytics.rules.aggregation.AggregationItem;
import org.opensearch.securityanalytics.rules.backend.OSQueryBackend.AggregationQueries;
import org.opensearch.securityanalytics.rules.condition.ConditionItem;
import org.opensearch.securityanalytics.rules.exceptions.SigmaError;
import org.opensearch.securityanalytics.rules.objects.SigmaCondition;
import org.opensearch.securityanalytics.rules.objects.SigmaRule;

import java.io.IOException;
Expand Down Expand Up @@ -56,6 +62,7 @@ public class Rule implements Writeable, ToXContentObject {

public static final String PRE_PACKAGED_RULES_INDEX = ".opensearch-sap-pre-packaged-rules-config";
public static final String CUSTOM_RULES_INDEX = ".opensearch-sap-custom-rules-config";
public static final String AGGREGATION_QUERIES = "aggregationQueries";

public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
Rule.class,
Expand Down Expand Up @@ -95,10 +102,12 @@ public class Rule implements Writeable, ToXContentObject {

private String rule;

private List<Value> aggregationQueries;

public Rule(String id, Long version, String title, String category, String logSource,
String description, List<Value> references, List<Value> tags, String level,
List<Value> falsePositives, String author, String status, Instant date,
List<Value> queries, List<Value> queryFieldNames, String rule) {
List<Value> queries, List<Value> queryFieldNames, String rule, List<Value> aggregationQueries) {
this.id = id != null? id: NO_ID;
this.version = version != null? version: NO_VERSION;

Expand All @@ -121,10 +130,11 @@ public Rule(String id, Long version, String title, String category, String logSo
this.queries = queries;
this.queryFieldNames = queryFieldNames;
this.rule = rule;
this.aggregationQueries = aggregationQueries;
}

public Rule(String id, Long version, SigmaRule rule, String category,
List<String> queries, List<String> queryFieldNames, String original) {
List<Object> queries, List<String> queryFieldNames, String original) {
this(
id,
version,
Expand All @@ -141,9 +151,11 @@ public Rule(String id, Long version, SigmaRule rule, String category,
rule.getAuthor(),
rule.getStatus().toString(),
Instant.ofEpochMilli(rule.getDate().getTime()),
queries.stream().map(Value::new).collect(Collectors.toList()),
queries.stream().filter(query -> !(query instanceof AggregationQueries)).map(query -> new Value(query.toString())).collect(Collectors.toList()),
queryFieldNames.stream().map(Value::new).collect(Collectors.toList()),
original);
original,
// If one of the queries is AggregationQuery -> the whole rule can be considered as Agg
queries.stream().filter(query -> query instanceof AggregationQueries).map(it -> new Value(it.toString())).collect(Collectors.toList()));
}

public Rule(StreamInput sin) throws IOException {
Expand All @@ -163,7 +175,9 @@ public Rule(StreamInput sin) throws IOException {
sin.readInstant(),
sin.readList(Value::readFrom),
sin.readList(Value::readFrom),
sin.readString());
sin.readString(),
sin.readList(Value::readFrom)
);
}

@Override
Expand All @@ -190,6 +204,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeCollection(queryFieldNames);

out.writeString(rule);
out.writeCollection(aggregationQueries);
}

@Override
Expand Down Expand Up @@ -233,6 +248,10 @@ private XContentBuilder createXContentBuilder(XContentBuilder builder, ToXConten
queryFieldNamesArray = queryFieldNames.toArray(queryFieldNamesArray);
builder.field(QUERY_FIELD_NAMES, queryFieldNamesArray);

Value[] aggregationsArray = new Value[]{};
aggregationsArray = aggregationQueries.toArray(aggregationsArray);
builder.field(AGGREGATION_QUERIES, aggregationsArray);

builder.field(RULE, rule);
if (params.paramAsBoolean("with_type", false)) {
builder.endObject();
Expand Down Expand Up @@ -278,6 +297,7 @@ public static Rule parse(XContentParser xcp, String id, Long version) throws IOE
List<Value> queries = new ArrayList<>();
List<Value> queryFields = new ArrayList<>();
String original = null;
List<Value> aggregationQueries = new ArrayList<>();

XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -342,6 +362,11 @@ public static Rule parse(XContentParser xcp, String id, Long version) throws IOE
case RULE:
original = xcp.text();
break;
case AGGREGATION_QUERIES:
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_ARRAY) {
aggregationQueries.add(Value.parse(xcp));
}
default:
xcp.skipChildren();
}
Expand All @@ -363,7 +388,8 @@ public static Rule parse(XContentParser xcp, String id, Long version) throws IOE
date,
queries,
queryFields,
Objects.requireNonNull(original, "Rule String is null")
Objects.requireNonNull(original, "Rule String is null"),
aggregationQueries
);
}

Expand Down Expand Up @@ -442,4 +468,21 @@ public List<Value> getQueries() {
public List<Value> getQueryFieldNames() {
return queryFieldNames;
}

public List<Value> getAggregationQueries() { return aggregationQueries; }

public boolean isAggregationRule() {
return aggregationQueries != null && !aggregationQueries.isEmpty();
}

public List<AggregationItem> getAggregationItemsFromRule () throws SigmaError {
SigmaRule sigmaRule = SigmaRule.fromYaml(rule, true);
List<AggregationItem> aggregationItems = new ArrayList<>();
for (SigmaCondition condition: sigmaRule.getDetection().getParsedCondition()) {
Pair<ConditionItem, AggregationItem> parsedItems = condition.parsed();
AggregationItem aggItem = parsedItems.getRight();
aggregationItems.add(aggItem);
}
return aggregationItems;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.securityanalytics.rules.backend;

import java.util.Locale;
import org.apache.commons.lang3.NotImplementedException;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.aggregations.metrics.AvgAggregationBuilder;
import org.opensearch.search.aggregations.metrics.MaxAggregationBuilder;
import org.opensearch.search.aggregations.metrics.MedianAbsoluteDeviationAggregationBuilder;
import org.opensearch.search.aggregations.metrics.MinAggregationBuilder;
import org.opensearch.search.aggregations.metrics.SumAggregationBuilder;
import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder;

public final class AggregationBuilders {

/**
* Finds the builder aggregation based on the forwarded function
*
* @param aggregationFunction Aggregation function
* @param name Name of the aggregation
* @return Aggregation builder
*/
public static AggregationBuilder getAggregationBuilderByFunction(String aggregationFunction, String name) {
AggregationBuilder aggregationBuilder;
switch (aggregationFunction.toLowerCase(Locale.ROOT)) {
case AvgAggregationBuilder.NAME:
aggregationBuilder = new AvgAggregationBuilder(name).field(name);
break;
case MaxAggregationBuilder.NAME:
aggregationBuilder = new MaxAggregationBuilder(name).field(name);
break;
case MedianAbsoluteDeviationAggregationBuilder.NAME:
aggregationBuilder = new MedianAbsoluteDeviationAggregationBuilder(name).field(name);
break;
case MinAggregationBuilder.NAME:
aggregationBuilder = new MinAggregationBuilder(name).field(name);
break;
case SumAggregationBuilder.NAME:
aggregationBuilder = new SumAggregationBuilder(name).field(name);
break;
case TermsAggregationBuilder.NAME:
aggregationBuilder = new TermsAggregationBuilder(name).field(name);
break;
case "count":
aggregationBuilder = new ValueCountAggregationBuilder(name).field(name);
break;
default:
throw new NotImplementedException(String.format(Locale.getDefault(), "Aggregation %s not supported by the backend", aggregationFunction));
}
return aggregationBuilder;
}
}
Loading