diff --git a/src/main/java/org/opensearch/securityanalytics/action/CorrelatedFindingAction.java b/src/main/java/org/opensearch/securityanalytics/action/CorrelatedFindingAction.java index d0551c505..f41bdc4aa 100644 --- a/src/main/java/org/opensearch/securityanalytics/action/CorrelatedFindingAction.java +++ b/src/main/java/org/opensearch/securityanalytics/action/CorrelatedFindingAction.java @@ -8,7 +8,7 @@ public class CorrelatedFindingAction extends ActionType { public static final CorrelatedFindingAction INSTANCE = new CorrelatedFindingAction(); - public static final String NAME = "cluster:admin/opensearch/securityanalytics/findings/correlated"; + public static final String NAME = "cluster:admin/opensearch/securityanalytics/correlations/findings"; public CorrelatedFindingAction() { super(NAME, CorrelatedFindingResponse::new); diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java index 0ae4ea129..5e4bb6629 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java @@ -185,7 +185,8 @@ public void onResponse(MultiSearchResponse items) { categoryToQueriesMap.put(query.getCategory(), correlationQueries); } } - searchFindingsByTimestamp(detectorType, categoryToQueriesMap); + searchFindingsByTimestamp(detectorType, categoryToQueriesMap, + filteredCorrelationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList())); } @Override @@ -194,7 +195,7 @@ public void onFailure(Exception e) { } }); } else { - correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding()); + correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), List.of()); } } @@ -202,7 +203,7 @@ public void onFailure(Exception e) { * this method searches for parent findings given the log category & correlation time window & collects all related docs * for them. */ - private void searchFindingsByTimestamp(String detectorType, Map> categoryToQueriesMap) { + private void searchFindingsByTimestamp(String detectorType, Map> categoryToQueriesMap, List correlationRules) { long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli(); MultiSearchRequest mSearchRequest = new MultiSearchRequest(); List>> categoryToQueriesPairs = new ArrayList<>(); @@ -255,7 +256,7 @@ public void onResponse(MultiSearchResponse items) { relatedDocIds)); ++idx; } - searchDocsWithFilterKeys(detectorType, relatedDocsMap); + searchDocsWithFilterKeys(detectorType, relatedDocsMap, correlationRules); } @Override @@ -264,14 +265,14 @@ public void onFailure(Exception e) { } }); } else { - correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding()); + correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules); } } /** * Given the related docs from parent findings, this method filters only those related docs which match parent join criteria. */ - private void searchDocsWithFilterKeys(String detectorType, Map relatedDocsMap) { + private void searchDocsWithFilterKeys(String detectorType, Map relatedDocsMap, List correlationRules) { MultiSearchRequest mSearchRequest = new MultiSearchRequest(); List categories = new ArrayList<>(); @@ -318,7 +319,7 @@ public void onResponse(MultiSearchResponse items) { filteredRelatedDocIds.put(categories.get(idx), docIds); ++idx; } - getCorrelatedFindings(detectorType, filteredRelatedDocIds); + getCorrelatedFindings(detectorType, filteredRelatedDocIds, correlationRules); } @Override @@ -327,7 +328,7 @@ public void onFailure(Exception e) { } }); } else { - correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding()); + correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules); } } @@ -335,7 +336,7 @@ public void onFailure(Exception e) { * Given the filtered related docs of the parent findings, this method gets the actual filtered parent findings for * the finding to be correlated. */ - private void getCorrelatedFindings(String detectorType, Map> filteredRelatedDocIds) { + private void getCorrelatedFindings(String detectorType, Map> filteredRelatedDocIds, List correlationRules) { long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli(); MultiSearchRequest mSearchRequest = new MultiSearchRequest(); List categories = new ArrayList<>(); @@ -390,7 +391,7 @@ public void onResponse(MultiSearchResponse items) { } ++idx; } - correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings); + correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules); } @Override @@ -399,7 +400,7 @@ public void onFailure(Exception e) { } }); } else { - correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding()); + correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules); } } diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java index 80944c08b..1e91835f6 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java @@ -58,7 +58,7 @@ public VectorEmbeddingsEngine(Client client, TimeValue indexTimeout, long corrTi this.correlateFindingAction = correlateFindingAction; } - public void insertCorrelatedFindings(String detectorType, Finding finding, String logType, List correlatedFindings, float timestampFeature) { + public void insertCorrelatedFindings(String detectorType, Finding finding, String logType, List correlatedFindings, float timestampFeature, List correlationRules) { long findingTimestamp = finding.getTimestamp().toEpochMilli(); MatchQueryBuilder queryBuilder = QueryBuilders.matchQuery( "root", true @@ -172,6 +172,7 @@ public void onResponse(MultiSearchResponse items) { corrBuilder.field("corr_vector", corrVector); corrBuilder.field("recordType", "finding-finding"); corrBuilder.field("scoreTimestamp", 0L); + corrBuilder.field("corrRules", correlationRules); corrBuilder.endObject(); IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_INDEX) diff --git a/src/main/java/org/opensearch/securityanalytics/model/CorrelatedFinding.java b/src/main/java/org/opensearch/securityanalytics/model/CorrelatedFinding.java index cfac3695e..d5f68339b 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/CorrelatedFinding.java +++ b/src/main/java/org/opensearch/securityanalytics/model/CorrelatedFinding.java @@ -13,6 +13,8 @@ import org.opensearch.core.xcontent.XContentParser; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; public class CorrelatedFinding implements Writeable, ToXContentObject { @@ -24,16 +26,20 @@ public class CorrelatedFinding implements Writeable, ToXContentObject { private String logType2; + private List correlationRules; + protected static final String FINDING1_FIELD = "finding1"; protected static final String LOGTYPE1_FIELD = "logType1"; protected static final String FINDING2_FIELD = "finding2"; protected static final String LOGTYPE2_FIELD = "logType2"; + protected static final String RULES_FIELD = "rules"; - public CorrelatedFinding(String finding1, String logType1, String finding2, String logType2) { + public CorrelatedFinding(String finding1, String logType1, String finding2, String logType2, List correlationRules) { this.finding1 = finding1; this.logType1 = logType1; this.finding2 = finding2; this.logType2 = logType2; + this.correlationRules = correlationRules; } public CorrelatedFinding(StreamInput sin) throws IOException { @@ -41,7 +47,8 @@ public CorrelatedFinding(StreamInput sin) throws IOException { sin.readString(), sin.readString(), sin.readString(), - sin.readString() + sin.readString(), + sin.readStringList() ); } @@ -51,6 +58,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(logType1); out.writeString(finding2); out.writeString(logType2); + out.writeStringCollection(correlationRules); } @Override @@ -59,7 +67,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(FINDING1_FIELD, finding1) .field(LOGTYPE1_FIELD, logType1) .field(FINDING2_FIELD, finding2) - .field(LOGTYPE2_FIELD, logType2); + .field(LOGTYPE2_FIELD, logType2) + .field(RULES_FIELD, correlationRules); return builder.endObject(); } @@ -68,6 +77,7 @@ public static CorrelatedFinding parse(XContentParser xcp) throws IOException { String logType1 = null; String finding2 = null; String logType2 = null; + List correlationRules = new ArrayList<>(); XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp); while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { @@ -87,11 +97,17 @@ public static CorrelatedFinding parse(XContentParser xcp) throws IOException { case LOGTYPE2_FIELD: logType2 = xcp.text(); break; + case RULES_FIELD: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + correlationRules.add(xcp.text()); + } + break; default: xcp.skipChildren(); } } - return new CorrelatedFinding(finding1, logType1, finding2, logType2); + return new CorrelatedFinding(finding1, logType1, finding2, logType2, correlationRules); } public static CorrelatedFinding readFrom(StreamInput sin) throws IOException { diff --git a/src/main/java/org/opensearch/securityanalytics/model/FindingWithScore.java b/src/main/java/org/opensearch/securityanalytics/model/FindingWithScore.java index d4d4cd2a3..2177d076e 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/FindingWithScore.java +++ b/src/main/java/org/opensearch/securityanalytics/model/FindingWithScore.java @@ -13,12 +13,15 @@ import org.opensearch.core.xcontent.XContentParser; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; public class FindingWithScore implements Writeable, ToXContentObject { protected static final String FINDING = "finding"; protected static final String DETECTOR_TYPE = "detector_type"; protected static final String SCORE = "score"; + protected static final String RULES = "rules"; private String finding; @@ -26,17 +29,21 @@ public class FindingWithScore implements Writeable, ToXContentObject { private Double score; - public FindingWithScore(String finding, String detectorType, Double score) { + private List rules; + + public FindingWithScore(String finding, String detectorType, Double score, List rules) { this.finding = finding; this.detectorType = detectorType; this.score = score; + this.rules = rules; } public FindingWithScore(StreamInput sin) throws IOException { this( sin.readString(), sin.readString(), - sin.readDouble() + sin.readDouble(), + sin.readStringList() ); } @@ -45,6 +52,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(finding); out.writeString(detectorType); out.writeDouble(score); + out.writeStringCollection(rules); } @Override @@ -53,6 +61,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(FINDING, finding) .field(DETECTOR_TYPE, detectorType) .field(SCORE, score) + .field(RULES, rules) .endObject(); return builder; } @@ -61,6 +70,7 @@ public static FindingWithScore parse(XContentParser xcp) throws IOException { String finding = null; String detectorType = null; Double score = null; + List rules = new ArrayList<>(); XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp); while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { @@ -77,9 +87,17 @@ public static FindingWithScore parse(XContentParser xcp) throws IOException { case SCORE: score = xcp.doubleValue(); break; + case RULES: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + rules.add(xcp.text()); + } + break; + default: + xcp.skipChildren(); } } - return new FindingWithScore(finding, detectorType, score); + return new FindingWithScore(finding, detectorType, score, rules); } public static FindingWithScore readFrom(StreamInput sin) throws IOException { diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java index 57f491ae3..f84b433db 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java @@ -118,6 +118,7 @@ public TransportCorrelateFindingAction(TransportService transportService, @Override protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { try { + log.info("hit here1"); PublishFindingsRequest transformedRequest = transformRequest(request); if (!this.correlationIndices.correlationIndexExists()) { @@ -157,6 +158,7 @@ public void onFailure(Exception e) { log.error(ex); } } else { + log.info("hit here2"); AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, actionListener); correlateFindingAction.start(); } @@ -184,9 +186,11 @@ public class AsyncCorrelateFindingAction { this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, this); this.vectorEmbeddingsEngine = new VectorEmbeddingsEngine(client, indexTimeout, corrTimeWindow, this); + log.info("hit here5"); } void start() { + log.info("hit here4"); TransportCorrelateFindingAction.this.threadPool.getThreadContext().stashContext(); String monitorId = request.getMonitorId(); Finding finding = request.getFinding(); @@ -246,8 +250,9 @@ public void onFailure(Exception e) { } } - public void initCorrelationIndex(String detectorType, Map> correlatedFindings) { + public void initCorrelationIndex(String detectorType, Map> correlatedFindings, List correlationRules) { try { + log.info("hit here6"); if (!IndexUtils.correlationIndexUpdated) { IndexUtils.updateIndexMapping( CorrelationIndices.CORRELATION_INDEX, @@ -257,7 +262,7 @@ public void initCorrelationIndex(String detectorType, Map> public void onResponse(AcknowledgedResponse response) { if (response.isAcknowledged()) { IndexUtils.correlationIndexUpdated(); - getTimestampFeature(detectorType, correlatedFindings, null); + getTimestampFeature(detectorType, correlatedFindings, null, correlationRules); } else { onFailures(new OpenSearchStatusException("Failed to create correlation Index", RestStatus.INTERNAL_SERVER_ERROR)); } @@ -270,14 +275,15 @@ public void onFailure(Exception e) { } ); } else { - getTimestampFeature(detectorType, correlatedFindings, null); + getTimestampFeature(detectorType, correlatedFindings, null, correlationRules); } } catch (IOException ex) { onFailures(ex); } } - public void getTimestampFeature(String detectorType, Map> correlatedFindings, Finding orphanFinding) { + public void getTimestampFeature(String detectorType, Map> correlatedFindings, Finding orphanFinding, List correlationRules) { + log.info("hit here7"); long findingTimestamp = this.request.getFinding().getTimestamp().toEpochMilli(); BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() .mustNot(QueryBuilders.termQuery("scoreTimestamp", 0L)); @@ -318,7 +324,7 @@ public void onResponse(IndexResponse response) { } for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), - Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue()); + Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), correlationRules); } } else { vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue()); @@ -341,7 +347,7 @@ public void onFailure(Exception e) { } for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), - timestampFeature); + timestampFeature, correlationRules); } } else { vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, timestampFeature); diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportListCorrelationAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportListCorrelationAction.java index 5aa8da895..09a488175 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportListCorrelationAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportListCorrelationAction.java @@ -40,7 +40,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; -public class TransportListCorrelationAction extends HandledTransportAction { +public class TransportListCorrelationAction extends HandledTransportAction implements SecureTransportAction { private static final Logger log = LogManager.getLogger(TransportListCorrelationAction.class); @@ -91,6 +91,7 @@ class AsyncListCorrelationAction { this.response =new AtomicReference<>(); } + @SuppressWarnings("unchecked") void start() { Long startTimestamp = request.getStartTimestamp(); Long endTimestamp = request.getEndTimestamp(); @@ -106,7 +107,7 @@ void start() { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.query(queryBuilder); searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(1); + searchSourceBuilder.size(10000); SearchRequest searchRequest = new SearchRequest(); searchRequest.indices(CorrelationIndices.CORRELATION_INDEX); searchRequest.source(searchSourceBuilder); @@ -128,7 +129,8 @@ public void onResponse(SearchResponse response) { source.get("finding1").toString(), source.get("logType").toString().split("-")[0], source.get("finding2").toString(), - source.get("logType").toString().split("-")[1]); + source.get("logType").toString().split("-")[1], + (List) source.get("corrRules")); correlatedFindings.add(correlatedFinding); } onOperation(new ListCorrelationsResponse(correlatedFindings)); diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchCorrelationAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchCorrelationAction.java index 296e42457..dde82e31f 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchCorrelationAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchCorrelationAction.java @@ -41,12 +41,14 @@ import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; -public class TransportSearchCorrelationAction extends HandledTransportAction { +public class TransportSearchCorrelationAction extends HandledTransportAction implements SecureTransportAction { private static final Logger log = LogManager.getLogger(TransportSearchCorrelationAction.class); @@ -97,6 +99,7 @@ class AsyncSearchCorrelationAction { this.response =new AtomicReference<>(); } + @SuppressWarnings("unchecked") void start() { String findingId = request.getFindingId(); Detector.DetectorType detectorType = request.getDetectorType(); @@ -191,7 +194,7 @@ public void onResponse(SearchResponse response) { @Override public void onResponse(MultiSearchResponse items) { MultiSearchResponse.Item[] responses = items.getResponses(); - Map, Double> correlatedFindings = new HashMap<>(); + Map, Pair>> correlatedFindings = new HashMap<>(); for (MultiSearchResponse.Item response : responses) { if (response.isFailure()) { @@ -206,26 +209,37 @@ public void onResponse(MultiSearchResponse items) { Pair findingKey1 = Pair.of(source.get("finding1").toString(), source.get("logType").toString().split("-")[0]); if (correlatedFindings.containsKey(findingKey1)) { - correlatedFindings.put(findingKey1, Math.max(correlatedFindings.get(findingKey1), hit.getScore())); + double score = Math.max(correlatedFindings.get(findingKey1).getLeft(), hit.getScore()); + Set rules = correlatedFindings.get(findingKey1).getRight(); + rules.addAll((List) source.get("corrRules")); + + correlatedFindings.put(findingKey1, Pair.of(score, rules)); } else { - correlatedFindings.put(findingKey1, (double) hit.getScore()); + Set rules = new HashSet<>((List) source.get("corrRules")); + correlatedFindings.put(findingKey1, Pair.of((double) hit.getScore(), rules)); } } if (!source.get("finding2").toString().equals(findingId)) { Pair findingKey2 = Pair.of(source.get("finding2").toString(), source.get("logType").toString().split("-")[1]); if (correlatedFindings.containsKey(findingKey2)) { - correlatedFindings.put(findingKey2, Math.max(correlatedFindings.get(findingKey2), hit.getScore())); + double score = Math.max(correlatedFindings.get(findingKey2).getLeft(), hit.getScore()); + Set rules = correlatedFindings.get(findingKey2).getRight(); + rules.addAll((List) source.get("corrRules")); + + correlatedFindings.put(findingKey2, Pair.of(score, rules)); } else { - correlatedFindings.put(findingKey2, (double) hit.getScore()); + Set rules = new HashSet<>((List) source.get("corrRules")); + correlatedFindings.put(findingKey2, Pair.of((double) hit.getScore(), rules)); } } } } List findingWithScores = new ArrayList<>(); - for (Map.Entry, Double> correlatedFinding: correlatedFindings.entrySet()) { - findingWithScores.add(new FindingWithScore(correlatedFinding.getKey().getKey(), correlatedFinding.getKey().getValue(), correlatedFinding.getValue())); + for (Map.Entry, Pair>> correlatedFinding: correlatedFindings.entrySet()) { + findingWithScores.add(new FindingWithScore(correlatedFinding.getKey().getKey(), correlatedFinding.getKey().getValue(), + correlatedFinding.getValue().getLeft(), new ArrayList<>(correlatedFinding.getValue().getRight()))); } onOperation(new CorrelatedFindingResponse(findingWithScores)); diff --git a/src/main/resources/mappings/correlation.json b/src/main/resources/mappings/correlation.json index 25b4176b2..8fbe2ea71 100644 --- a/src/main/resources/mappings/correlation.json +++ b/src/main/resources/mappings/correlation.json @@ -37,6 +37,15 @@ }, "scoreTimestamp": { "type": "long" + }, + "corrRules": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } } } } \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java index 8ba0fb82d..cb231c5b8 100644 --- a/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java @@ -30,7 +30,7 @@ public class CorrelationEngineRestApiIT extends SecurityAnalyticsRestTestCase { @SuppressWarnings("unchecked") - public void testBasicCorrelationEngineWorkflow() throws IOException, InterruptedException { + public void testBasicCorrelationEngineWorkflow() throws IOException { LogIndices indices = createIndices(); String vpcFlowMonitorId = createVpcFlowDetector(indices.vpcFlowsIndex); @@ -39,21 +39,13 @@ public void testBasicCorrelationEngineWorkflow() throws IOException, Interrupted String appLogsMonitorId = createAppLogsDetector(indices.appLogsIndex); String s3MonitorId = createS3Detector(indices.s3AccessLogsIndex); - createNetworkToAdLdapToWindowsRule(indices); + String ruleId = createNetworkToAdLdapToWindowsRule(indices); createWindowsToAppLogsToS3LogsRule(indices); - indexDoc(indices.vpcFlowsIndex, "1", randomVpcFlowDoc()); - Response executeResponse = executeAlertingMonitor(vpcFlowMonitorId, Collections.emptyMap()); + indexDoc(indices.adLdapLogsIndex, "22", randomAdLdapDoc()); + Response executeResponse = executeAlertingMonitor(adLdapMonitorId, Collections.emptyMap()); Map executeResults = entityAsMap(executeResponse); int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); - Assert.assertEquals(1, noOfSigmaRuleMatches); - - Thread.sleep(2000); - - indexDoc(indices.adLdapLogsIndex, "22", randomAdLdapDoc()); - executeResponse = executeAlertingMonitor(adLdapMonitorId, Collections.emptyMap()); - executeResults = entityAsMap(executeResponse); - noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); Assert.assertEquals(0, noOfSigmaRuleMatches); indexDoc(indices.windowsIndex, "2", randomDoc()); @@ -74,6 +66,12 @@ public void testBasicCorrelationEngineWorkflow() throws IOException, Interrupted noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); Assert.assertEquals(0, noOfSigmaRuleMatches); + indexDoc(indices.vpcFlowsIndex, "1", randomVpcFlowDoc()); + executeResponse = executeAlertingMonitor(vpcFlowMonitorId, Collections.emptyMap()); + executeResults = entityAsMap(executeResponse); + noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + Assert.assertEquals(1, noOfSigmaRuleMatches); + // Call GetFindings API Map params = new HashMap<>(); params.put("detectorType", "test_windows"); @@ -83,6 +81,9 @@ public void testBasicCorrelationEngineWorkflow() throws IOException, Interrupted List> correlatedFindings = searchCorrelatedFindings(finding, "test_windows", 300000L, 10); Assert.assertEquals(1, correlatedFindings.size()); + Assert.assertTrue(correlatedFindings.get(0).get("rules") instanceof List); + Assert.assertEquals(1, ((List) correlatedFindings.get(0).get("rules")).size()); + Assert.assertEquals(ruleId, ((List) correlatedFindings.get(0).get("rules")).get(0)); } @SuppressWarnings("unchecked") @@ -96,19 +97,19 @@ public void testListCorrelationsWorkflow() throws IOException, InterruptedExcept createNetworkToAdLdapToWindowsRule(indices); Thread.sleep(30000); - indexDoc(indices.vpcFlowsIndex, "1", randomVpcFlowDoc()); - Response executeResponse = executeAlertingMonitor(vpcFlowMonitorId, Collections.emptyMap()); + indexDoc(indices.windowsIndex, "2", randomDoc()); + Response executeResponse = executeAlertingMonitor(testWindowsMonitorId, Collections.emptyMap()); Map executeResults = entityAsMap(executeResponse); int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); - Assert.assertEquals(1, noOfSigmaRuleMatches); + Assert.assertEquals(5, noOfSigmaRuleMatches); Thread.sleep(30000); - indexDoc(indices.windowsIndex, "2", randomDoc()); - executeResponse = executeAlertingMonitor(testWindowsMonitorId, Collections.emptyMap()); + indexDoc(indices.vpcFlowsIndex, "1", randomVpcFlowDoc()); + executeResponse = executeAlertingMonitor(vpcFlowMonitorId, Collections.emptyMap()); executeResults = entityAsMap(executeResponse); noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); - Assert.assertEquals(5, noOfSigmaRuleMatches); + Assert.assertEquals(1, noOfSigmaRuleMatches); Thread.sleep(30000); Long endTime = System.currentTimeMillis(); @@ -132,7 +133,7 @@ private LogIndices createIndices() throws IOException { return indices; } - private void createNetworkToAdLdapToWindowsRule(LogIndices indices) throws IOException { + private String createNetworkToAdLdapToWindowsRule(LogIndices indices) throws IOException { CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, "dstaddr:4.5.6.7", "network"); CorrelationQuery query2 = new CorrelationQuery(indices.adLdapLogsIndex, "ResultType:50126", "ad_ldap"); CorrelationQuery query4 = new CorrelationQuery(indices.windowsIndex, "Domain:NTAUTHORI*", "test_windows"); @@ -143,9 +144,10 @@ private void createNetworkToAdLdapToWindowsRule(LogIndices indices) throws IOExc Response response = client().performRequest(request); Assert.assertEquals(201, response.getStatusLine().getStatusCode()); + return entityAsMap(response).get("_id").toString(); } - private void createWindowsToAppLogsToS3LogsRule(LogIndices indices) throws IOException { + private String createWindowsToAppLogsToS3LogsRule(LogIndices indices) throws IOException { CorrelationQuery query1 = new CorrelationQuery(indices.windowsIndex, "HostName:EC2AMAZ-EPO7HKA", "test_windows"); CorrelationQuery query2 = new CorrelationQuery(indices.appLogsIndex, "endpoint:\\/customer_records.txt", "ad_ldap"); CorrelationQuery query4 = new CorrelationQuery(indices.s3AccessLogsIndex, "aws.cloudtrail.eventName:ReplicateObject", "s3"); @@ -156,6 +158,7 @@ private void createWindowsToAppLogsToS3LogsRule(LogIndices indices) throws IOExc Response response = client().performRequest(request); Assert.assertEquals(201, response.getStatusLine().getStatusCode()); + return entityAsMap(response).get("_id").toString(); } @SuppressWarnings("unchecked")