diff --git a/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java b/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java index a696971da..2cd4e4d42 100644 --- a/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java +++ b/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java @@ -5,6 +5,7 @@ package org.opensearch.securityanalytics.findings; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -16,6 +17,7 @@ import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.commons.alerting.AlertingPluginInterface; +import org.opensearch.commons.alerting.model.DocLevelQuery; import org.opensearch.commons.alerting.model.FindingWithDocs; import org.opensearch.commons.alerting.model.Table; import org.opensearch.rest.RestStatus; @@ -83,9 +85,9 @@ public void onFailure(Exception e) { }; // monitor --> detectorId mapping - Map monitorToDetectorMapping = new HashMap<>(); + Map monitorToDetectorMapping = new HashMap<>(); detector.getMonitorIds().forEach( - monitorId -> monitorToDetectorMapping.put(monitorId, detector.getId()) + monitorId -> monitorToDetectorMapping.put(monitorId, detector) ); // Get findings for all monitor ids FindingsService.this.getFindingsByMonitorIds( @@ -112,7 +114,7 @@ public void onFailure(Exception e) { * @param listener ActionListener to get notified on response or error */ public void getFindingsByMonitorIds( - Map monitorToDetectorMapping, + Map monitorToDetectorMapping, List monitorIds, String findingIndexName, Table table, @@ -169,11 +171,11 @@ public void getFindings( List allMonitorIds = new ArrayList<>(); // Used to convert monitorId back to detectorId to store in result FindingDto - Map monitorToDetectorMapping = new HashMap<>(); + Map monitorToDetectorMapping = new HashMap<>(); detectors.forEach(detector -> { // monitor --> detector map detector.getMonitorIds().forEach( - monitorId -> monitorToDetectorMapping.put(monitorId, detector.getId()) + monitorId -> monitorToDetectorMapping.put(monitorId, detector) ); // all monitorIds allMonitorIds.addAll(detector.getMonitorIds()); @@ -201,13 +203,21 @@ public void onFailure(Exception e) { ); } - public FindingDto mapFindingWithDocsToFindingDto(FindingWithDocs findingWithDocs, String detectorId) { + public FindingDto mapFindingWithDocsToFindingDto(FindingWithDocs findingWithDocs, Detector detector) { + List docLevelQueries = findingWithDocs.getFinding().getDocLevelQueries(); + if (docLevelQueries.isEmpty()) { // this is finding generated by a bucket level monitor + for (Map.Entry entry : detector.getRuleIdMonitorIdMap().entrySet()) { + if(entry.getValue().equals(findingWithDocs.getFinding().getMonitorId())) { + docLevelQueries = Collections.singletonList(new DocLevelQuery(entry.getKey(),"","",Collections.emptyList())); + } + } + } return new FindingDto( - detectorId, + detector.getId(), findingWithDocs.getFinding().getId(), findingWithDocs.getFinding().getRelatedDocIds(), findingWithDocs.getFinding().getIndex(), - findingWithDocs.getFinding().getDocLevelQueries(), + docLevelQueries, findingWithDocs.getFinding().getTimestamp(), findingWithDocs.getDocuments() ); diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java index 5f03ab958..6344067ef 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorRestApiIT.java @@ -309,9 +309,9 @@ public void testCreatingADetectorWithAggregationRules() throws IOException { Map responseBody = asMap(createResponse); - String createdRuleId = responseBody.get("_id").toString(); + String detectorId = responseBody.get("_id").toString(); - DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(createdRuleId)), + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(detectorId)), getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList())); Detector detector = randomDetectorWithInputs(List.of(input)); @@ -320,11 +320,11 @@ public void testCreatingADetectorWithAggregationRules() throws IOException { responseBody = asMap(createResponse); - createdRuleId = responseBody.get("_id").toString(); + detectorId = responseBody.get("_id").toString(); int createdVersion = Integer.parseInt(responseBody.get("_version").toString()); - Assert.assertNotEquals("response is missing Id", Detector.NO_ID, createdRuleId); + Assert.assertNotEquals("response is missing Id", Detector.NO_ID, detectorId); Assert.assertTrue("incorrect version", createdVersion > 0); - Assert.assertEquals("Incorrect Location header", String.format(Locale.getDefault(), "%s/%s", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, createdRuleId), createResponse.getHeader("Location")); + Assert.assertEquals("Incorrect Location header", String.format(Locale.getDefault(), "%s/%s", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, detectorId), createResponse.getHeader("Location")); Assert.assertFalse(((Map) responseBody.get("detector")).containsKey("rule_topic_index")); Assert.assertFalse(((Map) responseBody.get("detector")).containsKey("findings_index")); Assert.assertFalse(((Map) responseBody.get("detector")).containsKey("alert_index")); @@ -332,7 +332,7 @@ public void testCreatingADetectorWithAggregationRules() throws IOException { String request = "{\n" + " \"query\" : {\n" + " \"match\":{\n" + - " \"_id\": \"" + createdRuleId + "\"\n" + + " \"_id\": \"" + detectorId + "\"\n" + " }\n" + " }\n" + "}"; @@ -370,11 +370,20 @@ public void testCreatingADetectorWithAggregationRules() throws IOException { Map executeResults = entityAsMap(executeResponse); // verify bucket level monitor findings Map params = new HashMap<>(); - params.put("detector_id", createdRuleId); + params.put("detector_id", detectorId); Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); Map getFindingsBody = entityAsMap(getFindingsResponse); assertNotNull(getFindingsBody); Assert.assertEquals(1, getFindingsBody.get("total_findings")); + List findings = (List) getFindingsBody.get("findings"); + Assert.assertEquals(findings.size(), 1); + HashMap finding = (HashMap) findings.get(0); + Assert.assertTrue(finding.containsKey("queries")); + HashMap docLevelQuery = (HashMap) ((List) finding.get("queries")).get(0); + String ruleId = docLevelQuery.get("id").toString(); + Response getResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), null); + String getDetectorResponseString = new String(getResponse.getEntity().getContent().readAllBytes()); + Assert.assertTrue(getDetectorResponseString.contains(ruleId)); } public void testUpdateADetector() throws IOException { String index = createTestIndex(randomIndex(), windowsIndexMapping());