Skip to content

Commit

Permalink
initial default use case addition
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Galitzky <amgalitz@amazon.com>
  • Loading branch information
amitgalitz committed Mar 15, 2024
1 parent 1f6573d commit c679732
Show file tree
Hide file tree
Showing 16 changed files with 820 additions and 18 deletions.
3 changes: 3 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ dependencies {

// ZipArchive dependencies used for integration tests
zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}"
zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${opensearch_build}"
zipArchive group: 'org.opensearch.plugin', name:'neural-search', version: "${opensearch_build}"

secureIntegTestPluginArchive group: 'org.opensearch.plugin', name:'opensearch-security', version: "${opensearch_build}"

configurations.all {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ private CommonValue() {}
public static final String PROVISION_WORKFLOW = "provision";
/** The field name for workflow steps. This field represents the name of the workflow steps to be fetched. */
public static final String WORKFLOW_STEP = "workflow_step";
/** The param name for default use case, used by the create workflow API */
public static final String USE_CASE = "use_case";

/*
* Constants associated with plugin configuration
Expand Down
120 changes: 120 additions & 0 deletions src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.common;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;

import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Enum encapsulating the different default use cases and templates we have stored
*/
public enum DefaultUseCases {

/** defaults file and substitution ready template for OpenAI embedding model */
OPEN_AI_EMBEDDING_MODEL_DEPLOY(
"open_ai_embedding_model_deploy",
"defaults/open-ai-embedding-defaults.json",
"substitutionTemplates/deploy-remote-model-template.json"
),
/** defaults file and substitution ready template for cohere embedding model */
COHERE_EMBEDDING_MODEL_DEPLOY(
"cohere-embedding_model_deploy",
"defaults/cohere-embedding-defaults.json",
"substitutionTemplates/deploy-remote-model-template-extra-params.json"
),
LOCAL_NEURAL_SPARSE_SEARCH(
"local_neural_sparse_search",
"defaults/local-sparse-search-defaults.json",
"substitutionTemplates/neural-sparse-local-template.json"
),
/** defaults file and substitution ready template for cohere embedding model */ // TODO: not finalized
COHERE_EMBEDDING_MODEL_DEPLOY_SEMANTIC_SEARCH(
"cohere-embedding_model_deploy_semantic_search",
"defaults/cohere-embedding-defaults.json",
"substitutionTemplates/deploy-remote-model-template-v1.json"
);

private final String useCaseName;
private final String defaultsFile;
private final String substitutionReadyFile;
private static final Logger logger = LogManager.getLogger(DefaultUseCases.class);
private static final Set<String> allResources = Stream.of(values()).map(DefaultUseCases::getDefaultsFile).collect(Collectors.toSet());

DefaultUseCases(String useCaseName, String defaultsFile, String substitutionReadyFile) {
this.useCaseName = useCaseName;
this.defaultsFile = defaultsFile;
this.substitutionReadyFile = substitutionReadyFile;
}

/**
* Returns the useCaseName for the given enum Constant
* @return the useCaseName of this use case.
*/
public String getUseCaseName() {
return useCaseName;
}

/**
* Returns the defaultsFile for the given enum Constant
* @return the defaultsFile of this for the given useCase.
*/
public String getDefaultsFile() {
return defaultsFile;
}

/**
* Returns the substitutionReadyFile for the given enum Constant
* @return the substitutionReadyFile of the given useCase
*/
public String getSubstitutionReadyFile() {
return substitutionReadyFile;
}

/**
* Gets the defaultsFile based on the given use case.
* @param useCaseName name of the given use case
* @return the deafultsFile for that usecase
* @throws FlowFrameworkException if the use case doesn't exist in enum
*/
public static String getDefaultsFileByUseCaseName(String useCaseName) throws FlowFrameworkException {
if (useCaseName != null && !useCaseName.isEmpty()) {
for (DefaultUseCases mapping : values()) {
if (useCaseName.equals(mapping.getUseCaseName())) {
return mapping.getDefaultsFile();
}
}
}
logger.error("Unable to find defaults file for use case: {}", useCaseName);
throw new FlowFrameworkException("Unable to find defaults file for use case: " + useCaseName, RestStatus.BAD_REQUEST);
}

/**
* Gets the substitutionReadyFile based on the given use case
* @param useCaseName name of the given use case
* @return the substitutionReadyFile which has the template
* @throws FlowFrameworkException if the use case doesn't exist in enum
*/
public static String getSubstitutionReadyFileByUseCaseName(String useCaseName) throws FlowFrameworkException {
if (useCaseName != null && !useCaseName.isEmpty()) {
for (DefaultUseCases mapping : values()) {
if (mapping.getUseCaseName().equals(useCaseName)) {
return mapping.getSubstitutionReadyFile();
}
}
}
logger.error("Unable to find substitution ready file for use case: {}", useCaseName);
throw new FlowFrameworkException("Unable to find substitution ready file for use case: " + useCaseName, RestStatus.BAD_REQUEST);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.common.DefaultUseCases;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.transport.CreateWorkflowAction;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand All @@ -35,6 +38,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
import static org.opensearch.flowframework.common.CommonValue.USE_CASE;
import static org.opensearch.flowframework.common.CommonValue.VALIDATION;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
Expand Down Expand Up @@ -78,6 +82,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
String workflowId = request.param(WORKFLOW_ID);
String[] validation = request.paramAsStringArray(VALIDATION, new String[] { "all" });
boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false);
String useCase = request.param(USE_CASE);
// If provisioning, consume all other params and pass to provision transport action
Map<String, String> params = provision
? request.params()
Expand Down Expand Up @@ -112,11 +117,54 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
);
}
try {
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Template template = Template.parse(parser);

WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, validation, provision, params);
Template template;
Map<String, String> useCaseDefaultsMap = Collections.emptyMap();
if (useCase != null) {
String json = ParseUtils.resourceToString("/" + DefaultUseCases.getSubstitutionReadyFileByUseCaseName(useCase));
String defaultsFilePath = DefaultUseCases.getDefaultsFileByUseCaseName(useCase);
useCaseDefaultsMap = ParseUtils.parseJsonFileToStringToStringMap("/" + defaultsFilePath);

if (request.hasContent()) {
try {
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Map<String, String> userDefaults = ParseUtils.parseStringToStringMap(parser);
// updates the default params with anything user has given that matches
for (Map.Entry<String, String> userDefaultsEntry : userDefaults.entrySet()) {
if (useCaseDefaultsMap.containsKey(userDefaultsEntry.getKey())) {
useCaseDefaultsMap.put(userDefaultsEntry.getKey(), userDefaultsEntry.getValue());
}
}
} catch (Exception ex) {
String errorMessage = "failure parsing request body when a use case is given";
logger.error(errorMessage, ex);
throw new FlowFrameworkException(errorMessage, ExceptionsHelper.status(ex));
}

}

json = (String) ParseUtils.conditionallySubstitute(json, null, useCaseDefaultsMap);

XContentParser parserTestJson = ParseUtils.jsonToParser(json);
ensureExpectedToken(XContentParser.Token.START_OBJECT, parserTestJson.currentToken(), parserTestJson);
template = Template.parse(parserTestJson);

} else {
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
template = Template.parse(parser);
}

WorkflowRequest workflowRequest = new WorkflowRequest(
workflowId,
template,
validation,
provision,
params,
useCase,
useCaseDefaultsMap
);

return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> {
XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS);
Expand All @@ -134,11 +182,14 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
channel.sendResponse(new BytesRestResponse(ExceptionsHelper.status(e), errorMessage));
}
}));

} catch (FlowFrameworkException e) {
logger.error("failed to prepare rest request", e);
return channel -> channel.sendResponse(
new BytesRestResponse(e.getRestStatus(), e.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS))
);
} catch (IOException e) {
} catch (Exception e) {
logger.error("failed to prepare rest request", e);
FlowFrameworkException ex = new FlowFrameworkException(
"IOException: template content invalid for specified Content-Type.",
RestStatus.BAD_REQUEST
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,23 @@ public class WorkflowRequest extends ActionRequest {
*/
private Map<String, String> params;

/**
* use case flag
*/
private String useCase;

/**
* Deafult params map from use case
*/
private Map<String, String> defaultParams;

/**
* Instantiates a new WorkflowRequest, set validation to all, no provisioning
* @param workflowId the documentId of the workflow
* @param template the use case template which describes the workflow
*/
public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) {
this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap());
this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), null, Collections.emptyMap());
}

/**
Expand All @@ -65,7 +75,7 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template)
* @param params The parameters from the REST path
*/
public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, Map<String, String> params) {
this(workflowId, template, new String[] { "all" }, true, params);
this(workflowId, template, new String[] { "all" }, true, params, null, Collections.emptyMap());
}

/**
Expand All @@ -75,13 +85,17 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template,
* @param validation flag to indicate if validation is necessary
* @param provision flag to indicate if provision is necessary
* @param params map of REST path params. If provision is false, must be an empty map.
* @param useCase default use case given
* @param defaultParams the params to be used in the substitution based on the default use case.
*/
public WorkflowRequest(
@Nullable String workflowId,
@Nullable Template template,
String[] validation,
boolean provision,
Map<String, String> params
Map<String, String> params,
String useCase,
Map<String, String> defaultParams
) {
this.workflowId = workflowId;
this.template = template;
Expand All @@ -91,6 +105,8 @@ public WorkflowRequest(
throw new IllegalArgumentException("Params may only be included when provisioning.");
}
this.params = params;
this.useCase = useCase;
this.defaultParams = defaultParams;
}

/**
Expand Down Expand Up @@ -150,6 +166,22 @@ public Map<String, String> getParams() {
return Map.copyOf(this.params);
}

/**
* Gets the params map
* @return the params map
*/
public String getUseCase() {
return this.useCase;
}

/**
* Gets the params map
* @return the params map
*/
public Map<String, String> getDefaultParams() {
return Map.copyOf(this.defaultParams);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
Expand Down
Loading

0 comments on commit c679732

Please sign in to comment.