diff --git a/src/main/java/org/opensearch/securityanalytics/action/GetAlertsRequest.java b/src/main/java/org/opensearch/securityanalytics/action/GetAlertsRequest.java index 1e0cb6113..4f2009459 100644 --- a/src/main/java/org/opensearch/securityanalytics/action/GetAlertsRequest.java +++ b/src/main/java/org/opensearch/securityanalytics/action/GetAlertsRequest.java @@ -5,6 +5,7 @@ package org.opensearch.securityanalytics.action; import java.io.IOException; +import java.time.Instant; import java.util.Locale; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; @@ -24,6 +25,10 @@ public class GetAlertsRequest extends ActionRequest { private String severityLevel; private String alertState; + private Instant startTime; + + private Instant endTime; + public static final String DETECTOR_ID = "detector_id"; public GetAlertsRequest( @@ -31,7 +36,9 @@ public GetAlertsRequest( String logType, Table table, String severityLevel, - String alertState + String alertState, + Instant startTime, + Instant endTime ) { super(); this.detectorId = detectorId; @@ -39,6 +46,8 @@ public GetAlertsRequest( this.table = table; this.severityLevel = severityLevel; this.alertState = alertState; + this.startTime = startTime; + this.endTime = endTime; } public GetAlertsRequest(StreamInput sin) throws IOException { this( @@ -46,7 +55,9 @@ public GetAlertsRequest(StreamInput sin) throws IOException { sin.readOptionalString(), Table.readFrom(sin), sin.readString(), - sin.readString() + sin.readString(), + sin.readOptionalInstant(), + sin.readOptionalInstant() ); } @@ -68,6 +79,8 @@ public void writeTo(StreamOutput out) throws IOException { table.writeTo(out); out.writeString(severityLevel); out.writeString(alertState); + out.writeOptionalInstant(startTime); + out.writeOptionalInstant(endTime); } public String getDetectorId() { @@ -89,4 +102,12 @@ public String getAlertState() { public String getLogType() { return logType; } + + public Instant getStartTime() { + return startTime; + } + + public Instant getEndTime() { + return endTime; + } } diff --git a/src/main/java/org/opensearch/securityanalytics/alerts/AlertsService.java b/src/main/java/org/opensearch/securityanalytics/alerts/AlertsService.java index a61fe9d35..730edbf2c 100644 --- a/src/main/java/org/opensearch/securityanalytics/alerts/AlertsService.java +++ b/src/main/java/org/opensearch/securityanalytics/alerts/AlertsService.java @@ -19,6 +19,9 @@ import org.opensearch.commons.alerting.model.Alert; import org.opensearch.commons.alerting.model.Table; import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.securityanalytics.action.AckAlertsResponse; import org.opensearch.securityanalytics.action.AlertDto; import org.opensearch.securityanalytics.action.GetAlertsResponse; @@ -29,6 +32,7 @@ import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.util.SecurityAnalyticsException; +import java.time.Instant; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; @@ -66,6 +70,8 @@ public void getAlertsByDetectorId( Table table, String severityLevel, String alertState, + Instant startTime, + Instant endTime, ActionListener listener ) { this.client.execute(GetDetectorAction.INSTANCE, new GetDetectorRequest(detectorId, -3L), new ActionListener<>() { @@ -88,6 +94,8 @@ public void onResponse(GetDetectorResponse getDetectorResponse) { table, severityLevel, alertState, + startTime, + endTime, new ActionListener<>() { @Override public void onResponse(GetAlertsResponse getAlertsResponse) { @@ -129,9 +137,11 @@ public void getAlertsByMonitorIds( Table table, String severityLevel, String alertState, + Instant startTime, + Instant endTime, ActionListener listener ) { - + BoolQueryBuilder boolQueryBuilder = getBoolQueryBuilder(startTime, endTime); org.opensearch.commons.alerting.action.GetAlertsRequest req = new org.opensearch.commons.alerting.action.GetAlertsRequest( table, @@ -141,7 +151,8 @@ public void getAlertsByMonitorIds( alertIndex, monitorIds, null, - null + null, + boolQueryBuilder ); AlertingPluginInterface.INSTANCE.getAlerts((NodeClient) client, req, new ActionListener<>() { @@ -178,6 +189,8 @@ public void getAlerts( Table table, String severityLevel, String alertState, + Instant startTime, + Instant endTime, ActionListener listener ) { if (detectors.size() == 0) { @@ -204,6 +217,8 @@ public void getAlerts( table, severityLevel, alertState, + startTime, + endTime, new ActionListener<>() { @Override public void onResponse(GetAlertsResponse getAlertsResponse) { @@ -246,7 +261,10 @@ private AlertDto mapAlertToAlertDto(Alert alert, String detectorId) { public void getAlerts(List alertIds, Detector detector, Table table, + Instant startTime, + Instant endTime, ActionListener actionListener) { + BoolQueryBuilder boolQueryBuilder = getBoolQueryBuilder(startTime, endTime); GetAlertsRequest request = new GetAlertsRequest( table, "ALL", @@ -255,7 +273,8 @@ public void getAlerts(List alertIds, DetectorMonitorConfig.getAllAlertsIndicesPattern(detector.getDetectorType()), null, null, - alertIds); + alertIds, + boolQueryBuilder); AlertingPluginInterface.INSTANCE.getAlerts( (NodeClient) client, request, actionListener); @@ -305,4 +324,17 @@ public void onFailure(Exception e) { } } + + private static BoolQueryBuilder getBoolQueryBuilder(Instant startTime, Instant endTime) { + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); + if (startTime != null && endTime != null) { + long startTimeMillis = startTime.toEpochMilli(); + long endTimeMillis = endTime.toEpochMilli(); + QueryBuilder timeRangeQuery = QueryBuilders.rangeQuery("start_time") + .from(startTimeMillis) // Greater than or equal to start time + .to(endTimeMillis); // Less than or equal to end time + boolQueryBuilder.filter(timeRangeQuery); + } + return boolQueryBuilder; + } } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetAlertsAction.java index 19322d0cd..0276db801 100644 --- a/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetAlertsAction.java +++ b/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetAlertsAction.java @@ -5,6 +5,8 @@ package org.opensearch.securityanalytics.resthandler; import java.io.IOException; +import java.time.DateTimeException; +import java.time.Instant; import java.util.List; import java.util.Locale; import org.opensearch.client.node.NodeClient; @@ -45,6 +47,26 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli int startIndex = request.paramAsInt("startIndex", 0); String searchString = request.param("searchString", ""); + Instant startTime = null; + String startTimeParam = request.param("startTime"); + if (startTimeParam != null && !startTimeParam.isEmpty()) { + try { + startTime = Instant.ofEpochMilli(Long.parseLong(startTimeParam)); + } catch (NumberFormatException | NullPointerException | DateTimeException e) { + startTime = Instant.now(); + } + } + + Instant endTime = null; + String endTimeParam = request.param("endTime"); + if (endTimeParam != null && !endTimeParam.isEmpty()) { + try { + endTime = Instant.ofEpochMilli(Long.parseLong(endTimeParam)); + } catch (NumberFormatException | NullPointerException | DateTimeException e) { + endTime = Instant.now(); + } + } + Table table = new Table( sortOrder, sortString, @@ -59,7 +81,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli detectorType, table, severityLevel, - alertState + alertState, + startTime, + endTime ); return channel -> client.execute( diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportAcknowledgeAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportAcknowledgeAlertsAction.java index 16679e9b2..0018a0f18 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportAcknowledgeAlertsAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportAcknowledgeAlertsAction.java @@ -82,6 +82,8 @@ public void onResponse(GetDetectorResponse getDetectorResponse) { request.getAlertIds(), getDetectorResponse.getDetector(), new Table("asc", "id", null, 10000, 0, null), + null, + null, getAlertsResponseStepListener ); getAlertsResponseStepListener.whenComplete(getAlertsResponse -> { diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportGetAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetAlertsAction.java index f01929fc9..c2bdd7a15 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportGetAlertsAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetAlertsAction.java @@ -91,6 +91,8 @@ protected void doExecute(Task task, GetAlertsRequest request, ActionListener { - ActionListener l = invocation.getArgument(6); + ActionListener l = invocation.getArgument(8); l.onResponse(getAlertsResponse); return null; - }).when(alertssService).getAlertsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(ActionListener.class)); + }).when(alertssService).getAlertsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(), any(), any(ActionListener.class)); // Call getFindingsByDetectorId Table table = new Table( @@ -205,7 +205,8 @@ public void testGetAlerts_success() { 0, null ); - alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), new ActionListener<>() { + alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), null, null, + new ActionListener<>() { @Override public void onResponse(GetAlertsResponse getAlertsResponse) { assertEquals(2, (int)getAlertsResponse.getTotalAlerts()); @@ -258,10 +259,10 @@ public void testGetFindings_getFindingsByMonitorIdFailures() { }).when(client).execute(eq(GetDetectorAction.INSTANCE), any(GetDetectorRequest.class), any(ActionListener.class)); doAnswer(invocation -> { - ActionListener l = invocation.getArgument(6); + ActionListener l = invocation.getArgument(8); l.onFailure(new IllegalArgumentException("Error getting findings")); return null; - }).when(alertssService).getAlertsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(ActionListener.class)); + }).when(alertssService).getAlertsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(), any(), any(ActionListener.class)); // Call getFindingsByDetectorId Table table = new Table( @@ -272,7 +273,8 @@ public void testGetFindings_getFindingsByMonitorIdFailures() { 0, null ); - alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), new ActionListener<>() { + alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), null, null, + new ActionListener<>() { @Override public void onResponse(GetAlertsResponse getAlertsResponse) { fail("this test should've failed"); @@ -307,7 +309,8 @@ public void testGetFindings_getDetectorFailure() { 0, null ); - alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), new ActionListener<>() { + alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), null, null, + new ActionListener<>() { @Override public void onResponse(GetAlertsResponse getAlertsResponse) { fail("this test should've failed"); diff --git a/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java b/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java index 5b7da7a00..e6f4eff6d 100644 --- a/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java +++ b/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java @@ -6,16 +6,17 @@ package org.opensearch.securityanalytics.alerts; import java.io.IOException; +import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import org.apache.hc.core5.http.HttpStatus; @@ -39,19 +40,17 @@ import org.opensearch.securityanalytics.model.DetectorInput; import org.opensearch.securityanalytics.model.DetectorRule; import org.opensearch.securityanalytics.model.DetectorTrigger; +import org.opensearch.test.rest.OpenSearchRestTestCase; -import static java.util.Collections.emptyList; import static org.opensearch.securityanalytics.TestHelpers.netFlowMappings; import static org.opensearch.securityanalytics.TestHelpers.randomAction; import static org.opensearch.securityanalytics.TestHelpers.randomAggregationRule; import static org.opensearch.securityanalytics.TestHelpers.randomDetector; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputs; -import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndThreatIntel; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndTriggers; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithTriggers; import static org.opensearch.securityanalytics.TestHelpers.randomDoc; -import static org.opensearch.securityanalytics.TestHelpers.randomDocWithIpIoc; import static org.opensearch.securityanalytics.TestHelpers.randomNetworkDoc; import static org.opensearch.securityanalytics.TestHelpers.randomIndex; import static org.opensearch.securityanalytics.TestHelpers.randomRule; @@ -60,7 +59,6 @@ import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ALERT_HISTORY_MAX_DOCS; import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ALERT_HISTORY_RETENTION_PERIOD; import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ALERT_HISTORY_ROLLOVER_PERIOD; -import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE; public class AlertsIT extends SecurityAnalyticsRestTestCase { @@ -181,6 +179,121 @@ public void testGetAlerts_success() throws IOException { assertEquals(((ArrayList) ackAlertsResponseMap.get("acknowledged")).size(), 1); } + @SuppressWarnings("unchecked") + public void testGetAlertsByStartTimeAndEndTimeSuccess() throws IOException, InterruptedException { + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + String rule = randomRule(); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", randomDetectorType()), + new StringEntity(rule), new BasicHeader("Content-Type", "application/json")); + Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String createdId = responseBody.get("_id").toString(); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response response = client().performRequest(createMappingRequest); + assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); + + createAlertingMonitorConfigIndex(null); + Action triggerAction = randomAction(createDestination()); + + Detector detector = randomDetectorWithInputsAndTriggers(List.of(new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(createdId)), + getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList()))), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(createdId), List.of(), List.of("attack.defense_evasion"), List.of(triggerAction), List.of()))); + + createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + responseBody = asMap(createResponse); + + final String detectorId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + + String monitorId = ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); + + indexDoc(index, "1", randomDoc()); + + Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); + Map executeResults = entityAsMap(executeResponse); + + int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + Assert.assertEquals(6, noOfSigmaRuleMatches); + + Assert.assertEquals(1, ((Map) executeResults.get("trigger_results")).values().size()); + + // Call GetAlerts API + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params, null); + Map getAlertsBody = asMap(getAlertsResponse); + // TODO enable asserts here when able + Assert.assertEquals(1, getAlertsBody.get("total_alerts")); + + Instant startTime = Instant.now(); + indexDoc(index, "2", randomDoc()); + indexDoc(index, "5", randomDoc()); + + executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); + executeResults = entityAsMap(executeResponse); + + noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + Assert.assertEquals(6, noOfSigmaRuleMatches); + + Assert.assertEquals(1, ((Map) executeResults.get("trigger_results")).values().size()); + Instant endTime = Instant.now(); + + indexDoc(index, "4", randomDoc()); + indexDoc(index, "6", randomDoc()); + + executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); + executeResults = entityAsMap(executeResponse); + + noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + Assert.assertEquals(6, noOfSigmaRuleMatches); + + AtomicBoolean success = new AtomicBoolean(true); + OpenSearchRestTestCase.waitUntil( + () -> { + try { + // Call GetAlerts API + Map alertParams = new HashMap<>(); + alertParams.put("detector_id", detectorId); + alertParams.put("startTime", String.valueOf(startTime.toEpochMilli())); + alertParams.put("endTime", String.valueOf(endTime.toEpochMilli())); + Response currGetAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, alertParams, null); + Map currGetAlertsBody = asMap(currGetAlertsResponse); + // TODO enable asserts here when able + success.set(Integer.parseInt(currGetAlertsBody.get("total_alerts").toString()) == 2); + } catch (IOException ex) { + success.set(false); + } + return success.get(); + }, 2, TimeUnit.MINUTES + ); + Assert.assertTrue(success.get()); + } + public void testGetAlerts_noDetector_failure() throws IOException { // Call GetAlerts API Map params = new HashMap<>();