Skip to content

Commit

Permalink
Add API to invoke a workflow
Browse files Browse the repository at this point in the history
Signed-off-by: Chase Engelbrecht <engechas@amazon.com>
  • Loading branch information
engechas committed Jan 30, 2024
1 parent ffc183f commit 3565081
Show file tree
Hide file tree
Showing 10 changed files with 247 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import org.opensearch.alerting.resthandler.RestGetWorkflowAction
import org.opensearch.alerting.resthandler.RestGetWorkflowAlertsAction
import org.opensearch.alerting.resthandler.RestIndexMonitorAction
import org.opensearch.alerting.resthandler.RestIndexWorkflowAction
import org.opensearch.alerting.resthandler.RestRunWorkflowAction
import org.opensearch.alerting.resthandler.RestSearchEmailAccountAction
import org.opensearch.alerting.resthandler.RestSearchEmailGroupAction
import org.opensearch.alerting.resthandler.RestSearchMonitorAction
Expand All @@ -63,6 +64,7 @@ import org.opensearch.alerting.transport.TransportGetWorkflowAction
import org.opensearch.alerting.transport.TransportGetWorkflowAlertsAction
import org.opensearch.alerting.transport.TransportIndexMonitorAction
import org.opensearch.alerting.transport.TransportIndexWorkflowAction
import org.opensearch.alerting.transport.TransportRunWorkflowAction
import org.opensearch.alerting.transport.TransportSearchEmailAccountAction
import org.opensearch.alerting.transport.TransportSearchEmailGroupAction
import org.opensearch.alerting.transport.TransportSearchMonitorAction
Expand All @@ -89,6 +91,7 @@ import org.opensearch.commons.alerting.model.QueryLevelTrigger
import org.opensearch.commons.alerting.model.ScheduledJob
import org.opensearch.commons.alerting.model.SearchInput
import org.opensearch.commons.alerting.model.Workflow
import org.opensearch.commons.alerting.settings.SharedSettings
import org.opensearch.core.action.ActionResponse
import org.opensearch.core.common.io.stream.NamedWriteableRegistry
import org.opensearch.core.common.io.stream.StreamInput
Expand Down Expand Up @@ -183,7 +186,8 @@ internal class AlertingPlugin : PainlessExtension, ActionPlugin, ScriptPlugin, R
RestGetWorkflowAlertsAction(),
RestGetFindingsAction(),
RestGetWorkflowAction(),
RestDeleteWorkflowAction()
RestDeleteWorkflowAction(),
RestRunWorkflowAction()
)
}

Expand All @@ -210,6 +214,7 @@ internal class AlertingPlugin : PainlessExtension, ActionPlugin, ScriptPlugin, R
ActionPlugin.ActionHandler(AlertingActions.INDEX_WORKFLOW_ACTION_TYPE, TransportIndexWorkflowAction::class.java),
ActionPlugin.ActionHandler(AlertingActions.GET_WORKFLOW_ACTION_TYPE, TransportGetWorkflowAction::class.java),
ActionPlugin.ActionHandler(AlertingActions.DELETE_WORKFLOW_ACTION_TYPE, TransportDeleteWorkflowAction::class.java),
ActionPlugin.ActionHandler(AlertingActions.RUN_WORKFLOW_ACTION_TYPE, TransportRunWorkflowAction::class.java),
ActionPlugin.ActionHandler(ExecuteWorkflowAction.INSTANCE, TransportExecuteWorkflowAction::class.java)
)
}
Expand Down Expand Up @@ -346,7 +351,8 @@ internal class AlertingPlugin : PainlessExtension, ActionPlugin, ScriptPlugin, R
AlertingSettings.FINDING_HISTORY_MAX_DOCS,
AlertingSettings.FINDING_HISTORY_INDEX_MAX_AGE,
AlertingSettings.FINDING_HISTORY_ROLLOVER_PERIOD,
AlertingSettings.FINDING_HISTORY_RETENTION_PERIOD
AlertingSettings.FINDING_HISTORY_RETENTION_PERIOD,
SharedSettings.STREAMING_SECURITY_ANALYTICS
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.opensearch.alerting.util.getCombinedTriggerRunResult
import org.opensearch.alerting.workflow.WorkflowRunContext
import org.opensearch.common.xcontent.LoggingDeprecationHandler
import org.opensearch.common.xcontent.XContentType
import org.opensearch.commons.alerting.action.IdDocPair
import org.opensearch.commons.alerting.model.Alert
import org.opensearch.commons.alerting.model.BucketLevelTrigger
import org.opensearch.commons.alerting.model.Finding
Expand Down Expand Up @@ -62,7 +63,8 @@ object BucketLevelMonitorRunner : MonitorRunner() {
periodEnd: Instant,
dryrun: Boolean,
workflowRunContext: WorkflowRunContext?,
executionId: String
executionId: String,
docs: List<IdDocPair>?
): MonitorRunResult<BucketLevelTriggerRunResult> {
val roles = MonitorRunnerService.getRolesForMonitor(monitor)
logger.debug("Running monitor: ${monitor.name} with roles: $roles Thread: ${Thread.currentThread().name}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ import org.opensearch.cluster.routing.Preference
import org.opensearch.cluster.routing.ShardRouting
import org.opensearch.cluster.service.ClusterService
import org.opensearch.common.xcontent.XContentFactory
import org.opensearch.common.xcontent.XContentHelper
import org.opensearch.common.xcontent.XContentType
import org.opensearch.commons.alerting.AlertingPluginInterface
import org.opensearch.commons.alerting.action.IdDocPair
import org.opensearch.commons.alerting.action.PublishFindingsRequest
import org.opensearch.commons.alerting.action.SubscribeFindingsResponse
import org.opensearch.commons.alerting.model.ActionExecutionResult
Expand Down Expand Up @@ -75,7 +77,8 @@ object DocumentLevelMonitorRunner : MonitorRunner() {
periodEnd: Instant,
dryrun: Boolean,
workflowRunContext: WorkflowRunContext?,
executionId: String
executionId: String,
docs: List<IdDocPair>?
): MonitorRunResult<DocumentLevelTriggerRunResult> {
logger.debug("Document-level-monitor is running ...")
val isTempMonitor = dryrun || monitor.id == Monitor.NO_ID
Expand Down Expand Up @@ -219,14 +222,20 @@ object DocumentLevelMonitorRunner : MonitorRunner() {
// Prepare DocumentExecutionContext for each index
val docExecutionContext = DocumentExecutionContext(queries, indexLastRunContext, indexUpdatedRunContext)

val matchingDocs = getMatchingDocs(
val matchingDocs = if (docs == null) getMatchingDocs(
monitor,
monitorCtx,
docExecutionContext,
updatedIndexName,
concreteIndexName,
conflictingFields.toList(),
matchingDocIdsPerIndex?.get(concreteIndexName)
) else getMatchingDocs(
docs,
updatedIndexName,
concreteIndexName,
monitor.id,
conflictingFields.toList()
)

if (matchingDocs.isNotEmpty()) {
Expand Down Expand Up @@ -309,10 +318,12 @@ object DocumentLevelMonitorRunner : MonitorRunner() {
onSuccessfulMonitorRun(monitorCtx, monitor)
}

MonitorMetadataService.upsertMetadata(
monitorMetadata.copy(lastRunContext = updatedLastRunContext),
true
)
if (docs == null) {
MonitorMetadataService.upsertMetadata(
monitorMetadata.copy(lastRunContext = updatedLastRunContext),
true
)
}
}

// TODO: Update the Document as part of the Trigger and return back the trigger action result
Expand Down Expand Up @@ -598,6 +609,30 @@ object DocumentLevelMonitorRunner : MonitorRunner() {
return allShards.filter { it.primary() }.size
}

private fun getMatchingDocs(
docs: List<IdDocPair>,
index: String,
concreteIndex: String,
monitorId: String,
conflictingFields: List<String>,
): List<Pair<String, BytesReference>> {
return docs.map { createIdToDocPair(it, index, concreteIndex, monitorId, conflictingFields) }.toList()
}

private fun createIdToDocPair(
idDocPair: IdDocPair,
index: String,
concreteIndex: String,
monitorId: String,
conflictingFields: List<String>
): Pair<String, BytesReference> {
// TODO - uses deprecated method. Can we avoid the transformations?
val sourceAsMap = XContentHelper.convertToMap(idDocPair.document, false, XContentType.JSON)
val transformedDoc = transformDocument(sourceAsMap.v2(), index, concreteIndex, monitorId, conflictingFields)

return Pair(idDocPair.docId, transformedDoc)
}

private suspend fun getMatchingDocs(
monitor: Monitor,
monitorCtx: MonitorRunnerExecutionContext,
Expand Down Expand Up @@ -731,25 +766,40 @@ object DocumentLevelMonitorRunner : MonitorRunner() {
): List<Pair<String, BytesReference>> {
return hits.map { hit ->
val sourceMap = hit.sourceAsMap

transformDocumentFieldNames(
val sourceRef = transformDocument(
sourceMap,
conflictingFields,
"_${index}_$monitorId",
"_${concreteIndex}_$monitorId",
""
index,
concreteIndex,
monitorId,
conflictingFields
)

var xContentBuilder = XContentFactory.jsonBuilder().map(sourceMap)

val sourceRef = BytesReference.bytes(xContentBuilder)

logger.debug("Document [${hit.id}] payload after transform: ", sourceRef.utf8ToString())

Pair(hit.id, sourceRef)
}
}

private fun transformDocument(
sourceMap: MutableMap<String, Any>,
index: String,
concreteIndex: String,
monitorId: String,
conflictingFields: List<String>
): BytesReference {
transformDocumentFieldNames(
sourceMap,
conflictingFields,
"_${index}_$monitorId",
"_${concreteIndex}_$monitorId",
""
)

var xContentBuilder = XContentFactory.jsonBuilder().map(sourceMap)

return BytesReference.bytes(xContentBuilder)
}

/**
* Traverses document fields in leaves recursively and appends [fieldNameSuffixIndex] to field names with same names
* but different mappings & [fieldNameSuffixPattern] to field names which have unique names.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.opensearch.alerting.util.isTestAction
import org.opensearch.alerting.util.use
import org.opensearch.alerting.workflow.WorkflowRunContext
import org.opensearch.client.node.NodeClient
import org.opensearch.commons.alerting.action.IdDocPair
import org.opensearch.commons.alerting.model.Monitor
import org.opensearch.commons.alerting.model.Table
import org.opensearch.commons.alerting.model.action.Action
Expand All @@ -43,7 +44,8 @@ abstract class MonitorRunner {
periodEnd: Instant,
dryRun: Boolean,
workflowRunContext: WorkflowRunContext? = null,
executionId: String
executionId: String,
docs: List<IdDocPair>? = null
): MonitorRunResult<*>

suspend fun runAction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import org.opensearch.alerting.opensearchapi.withClosableContext
import org.opensearch.alerting.script.QueryLevelTriggerExecutionContext
import org.opensearch.alerting.util.isADMonitor
import org.opensearch.alerting.workflow.WorkflowRunContext
import org.opensearch.commons.alerting.action.IdDocPair
import org.opensearch.commons.alerting.model.Alert
import org.opensearch.commons.alerting.model.Monitor
import org.opensearch.commons.alerting.model.QueryLevelTrigger
Expand All @@ -28,7 +29,8 @@ object QueryLevelMonitorRunner : MonitorRunner() {
periodEnd: Instant,
dryrun: Boolean,
workflowRunContext: WorkflowRunContext?,
executionId: String
executionId: String,
docs: List<IdDocPair>?
): MonitorRunResult<QueryLevelTriggerRunResult> {
val roles = MonitorRunnerService.getRolesForMonitor(monitor)
logger.debug("Running monitor: ${monitor.name} with roles: $roles Thread: ${Thread.currentThread().name}")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package org.opensearch.alerting.resthandler

import org.apache.logging.log4j.LogManager
import org.opensearch.alerting.AlertingPlugin
import org.opensearch.client.node.NodeClient
import org.opensearch.commons.alerting.action.AlertingActions
import org.opensearch.commons.alerting.action.RunWorkflowRequest
import org.opensearch.core.xcontent.XContentParser
import org.opensearch.core.xcontent.XContentParserUtils
import org.opensearch.rest.BaseRestHandler
import org.opensearch.rest.RestHandler
import org.opensearch.rest.RestRequest
import org.opensearch.rest.action.RestToXContentListener

class RestRunWorkflowAction : BaseRestHandler() {
private val log = LogManager.getLogger(javaClass)

override fun getName(): String {
return "run_workflow_action"
}

override fun routes(): List<RestHandler.Route> {
return listOf(
RestHandler.Route(
RestRequest.Method.POST,
"${AlertingPlugin.WORKFLOW_BASE_URI}/{workflowID}/run"
)
)
}

override fun prepareRequest(request: RestRequest, client: NodeClient): RestChannelConsumer {
log.debug("${request.method()} ${AlertingPlugin.WORKFLOW_BASE_URI}/{workflowID}/run")

val xcp = request.contentParser()
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.nextToken(), xcp)
val runWorkflowRequest = RunWorkflowRequest.parse(xcp)

return RestChannelConsumer {
channel ->
client.execute(
AlertingActions.RUN_WORKFLOW_ACTION_TYPE,
runWorkflowRequest,
RestToXContentListener(channel)
)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.alerting.transport

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import org.apache.logging.log4j.LogManager
import org.opensearch.ResourceNotFoundException
import org.opensearch.action.ActionRequest
import org.opensearch.action.support.ActionFilters
import org.opensearch.action.support.HandledTransportAction
import org.opensearch.alerting.MonitorRunnerService
import org.opensearch.alerting.opensearchapi.suspendUntil
import org.opensearch.alerting.util.AlertingException
import org.opensearch.alerting.workflow.CompositeWorkflowRunner
import org.opensearch.client.Client
import org.opensearch.common.inject.Inject
import org.opensearch.common.io.stream.BytesStreamOutput
import org.opensearch.commons.alerting.action.AlertingActions
import org.opensearch.commons.alerting.action.GetWorkflowRequest
import org.opensearch.commons.alerting.action.GetWorkflowResponse
import org.opensearch.commons.alerting.action.RunWorkflowRequest
import org.opensearch.commons.alerting.action.RunWorkflowResponse
import org.opensearch.core.action.ActionListener
import org.opensearch.core.rest.RestStatus
import org.opensearch.rest.RestRequest
import org.opensearch.tasks.Task
import org.opensearch.transport.TransportService
import java.time.Instant

private val log = LogManager.getLogger(TransportRunWorkflowAction::class.java)
private val scope: CoroutineScope = CoroutineScope(Dispatchers.IO)

class TransportRunWorkflowAction @Inject constructor(
transportService: TransportService,
val client: Client,
actionFilters: ActionFilters,
val transportGetWorkflowAction: TransportGetWorkflowAction
) : HandledTransportAction<ActionRequest, RunWorkflowResponse>(
AlertingActions.RUN_WORKFLOW_ACTION_NAME, transportService, actionFilters,
::RunWorkflowRequest
) {

override fun doExecute(task: Task, request: ActionRequest, actionListener: ActionListener<RunWorkflowResponse>) {
// TODO ser/de here is a hack to avoid a ClassCastException. Security Analytics and Alerting both use the RunWorkflowRequest object
// but have different ClassLoaders - https://stackoverflow.com/a/826345
val outputStream = BytesStreamOutput()
request.writeTo(outputStream)
val actualRequest = RunWorkflowRequest(outputStream.copyBytes().streamInput())

scope.launch {
val getWorkflowResponse: GetWorkflowResponse =
transportGetWorkflowAction.client.suspendUntil {
val getWorkflowRequest = GetWorkflowRequest(actualRequest.workflowId, RestRequest.Method.GET)
execute(AlertingActions.GET_WORKFLOW_ACTION_TYPE, getWorkflowRequest, it)
}

if (getWorkflowResponse.workflow != null) {
try {
// TODO - is using monitorCtx like this safe?
// TODO - is Instant.now() fine?
val workflowRunResult = CompositeWorkflowRunner.runWorkflow(
getWorkflowResponse.workflow!!,
MonitorRunnerService.monitorCtx,
Instant.now(),
Instant.now(),
false,
actualRequest.documents
)

if (workflowRunResult.error != null) {
actionListener.onFailure(workflowRunResult.error)
} else {
actionListener.onResponse(RunWorkflowResponse(RestStatus.OK))
}
} catch (e: Exception) {
actionListener.onFailure(
AlertingException.wrap(e)
)
}
} else {
actionListener.onFailure(
AlertingException.wrap(
ResourceNotFoundException("Workflow with id ${actualRequest.workflowId} not found", RestStatus.NOT_FOUND)
)
)
}
}
}
}
Loading

0 comments on commit 3565081

Please sign in to comment.