Skip to content

Commit

Permalink
[ML] Make semantic search an indices action (#90887)
Browse files Browse the repository at this point in the history
Users do not require an ml privilege to call _semantic_search
  • Loading branch information
davidkyle committed Oct 31, 2022
1 parent 0799663 commit 1d0a309
Show file tree
Hide file tree
Showing 12 changed files with 383 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.IndicesRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -43,17 +46,15 @@

public class SemanticSearchAction extends ActionType<SemanticSearchAction.Response> {

// TODO what should this be called? If this becomes an indices action
// change the transport code to run _infer as ML_ORIGIN
public static final String NAME = "cluster:monitor/xpack/ml/semantic_search";
public static final String NAME = "indices:data/read/semantic_search";

public static final SemanticSearchAction INSTANCE = new SemanticSearchAction(NAME);

private SemanticSearchAction(String name) {
super(name, SemanticSearchAction.Response::new);
}

public static class Request extends ActionRequest {
public static class Request extends ActionRequest implements IndicesRequest.Replaceable {

public static final ParseField QUERY_STRING = new ParseField("query_string"); // TODO a better name and update docs when changed

Expand Down Expand Up @@ -111,7 +112,7 @@ public static Request parseRestRequest(RestRequest restRequest) throws IOExcepti
return builder.build();
}

private final String[] indices;
private String[] indices;
private final String routing;
private final String queryString;
private final String modelId;
Expand Down Expand Up @@ -158,7 +159,7 @@ public Request(StreamInput in) throws IOException {
List<FieldAndFormat> docValueFields,
StoredFieldsContext storedFields
) {
this.indices = indices;
this.indices = Objects.requireNonNull(indices, "[indices] must not be null");
this.routing = routing;
this.queryString = queryString;
this.modelId = modelId;
Expand Down Expand Up @@ -194,10 +195,26 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalWriteable(storedFields);
}

public String[] getIndices() {
@Override
public String[] indices() {
return indices;
}

@Override
public IndicesOptions indicesOptions() {
return SearchRequest.DEFAULT_INDICES_OPTIONS;
}

@Override
public IndicesRequest indices(String... indices) {
Objects.requireNonNull(indices, "indices must not be null");
for (String index : indices) {
Objects.requireNonNull(index, "index must not be null");
}
this.indices = indices;
return this;
}

public String getRouting() {
return routing;
}
Expand Down Expand Up @@ -312,7 +329,7 @@ public static class Builder {
private StoredFieldsContext storedFields;

Builder(String[] indices) {
this.indices = indices;
this.indices = Objects.requireNonNull(indices, "[indices] must not be null");
}

void setRouting(String routing) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,20 @@ public void testValidate() {
var validAction = createTestInstance();
assertNull(validAction.validate());

var action = new SemanticSearchAction.Request(null, null, null, null, null, null, null, null, null, null, null, null);
var action = new SemanticSearchAction.Request(
new String[] { "foo" },
null,
null,
null,
null,
null,
null,
null,
null,
null,
null,
null
);
var validation = action.validate();
assertNotNull(validation);
assertThat(validation.validationErrors(), hasSize(3));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.action.update.UpdateAction;
import org.elasticsearch.common.util.iterable.Iterables;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.SemanticSearchAction;
import org.elasticsearch.xpack.core.rollup.action.GetRollupIndexCapsAction;
import org.elasticsearch.xpack.core.transform.action.GetCheckpointAction;

Expand Down Expand Up @@ -80,6 +81,10 @@ public void testPrivilegesForGetCheckPointAction() {
);
}

public void testPrivilegesForSemanticSearchAction() {
assertThat(findPrivilegesThatGrant(SemanticSearchAction.NAME), equalTo(List.of("read", "all")));
}

public void testRelationshipBetweenPrivileges() {
assertThat(
Operations.subsetOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ protected Settings restAdminSettings() {
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
}

@Override
protected Map<String, String> getApiCallHeaders() {
return Collections.singletonMap(
"Authorization",
Expand Down
23 changes: 23 additions & 0 deletions x-pack/plugin/ml/qa/semantic-search-tests/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
apply plugin: 'elasticsearch.internal-yaml-rest-test'

dependencies {
yamlRestTestImplementation(testArtifact(project(xpackModule('core'))))
yamlRestTestImplementation(testArtifact(project(':x-pack:plugin')))
}

// bring in machine learning rest test suite
restResources {
restApi {
include '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'bulk', 'ml', 'semantic_search'
}
}

testClusters.configureEach {
testDistribution = 'DEFAULT'
rolesFile file('roles.yml')
user username: "x_pack_rest_user", password: "x-pack-test-password"
user username: "read_index_no_ml", password: "read_index_no_ml_password", role: "all_data"
user username: "no_read_index_no_ml", password: "no_read_index_no_ml_password", role: "unrelated_index_only"
setting 'xpack.license.self_generated.type', 'trial'
setting 'xpack.security.enabled', 'true'
}
29 changes: 29 additions & 0 deletions x-pack/plugin/ml/qa/semantic-search-tests/roles.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
all_data:
cluster:
# This is always required because the REST client uses it to find the version of
# Elasticsearch it's talking to
- cluster:monitor/main
indices:
# User
- names: [ 'embedded_text', 'unrelated' ]
privileges:
- create_index
- indices:admin/refresh
- read
- write
- view_index_metadata

unrelated_index_only:
cluster:
# This is always required because the REST client uses it to find the version of
# Elasticsearch it's talking to
- cluster:monitor/main
indices:
#
- names: [ 'unrelated' ]
privileges:
- create_index
- indices:admin/refresh
- read
- write
- view_index_metadata
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.smoketest;

import com.carrotsearch.randomizedtesting.annotations.Name;

import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
import org.elasticsearch.xpack.test.rest.AbstractXPackRestTest;

import java.util.Collections;
import java.util.Map;

public abstract class AbstractSemanticSearchPermissionsIT extends AbstractXPackRestTest {

private static final String TEST_ADMIN_USERNAME = "x_pack_rest_user";

public AbstractSemanticSearchPermissionsIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
super(testCandidate);
}

protected abstract String[] getCredentials();

@Override
protected Settings restClientSettings() {
String[] creds = getCredentials();
String token = basicAuthHeaderValue(creds[0], new SecureString(creds[1].toCharArray()));
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
}

@Override
protected Settings restAdminSettings() {
String token = basicAuthHeaderValue(TEST_ADMIN_USERNAME, SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
}

@Override
protected Map<String, String> getApiCallHeaders() {
return Collections.singletonMap(
"Authorization",
basicAuthHeaderValue(TEST_ADMIN_USERNAME, SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING)
);
}

@Override
protected boolean isMachineLearningTest() {
return true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.smoketest;

import com.carrotsearch.randomizedtesting.annotations.Name;

import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
import org.elasticsearch.test.rest.yaml.section.DoSection;
import org.elasticsearch.test.rest.yaml.section.ExecutableSection;

import java.io.IOException;

import static org.hamcrest.Matchers.containsString;

public class SemanticSearchNoReadPermissionsIT extends AbstractSemanticSearchPermissionsIT {

private static final String USERNAME = "no_read_index_no_ml";

private final ClientYamlTestCandidate testCandidate;

public SemanticSearchNoReadPermissionsIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
super(testCandidate);
this.testCandidate = testCandidate;
}

@Override
protected String[] getCredentials() {
return new String[] { USERNAME, "no_read_index_no_ml_password" };
}

@Override
public void test() throws IOException {
try {
// Cannot use expectThrows here because blacklisted tests will throw an
// InternalAssumptionViolatedException rather than an AssertionError
super.test();

for (ExecutableSection section : testCandidate.getTestSection().getExecutableSections()) {
if (section instanceof DoSection doSection) {
String apiName = doSection.getApiCallSection().getApi();
fail("call to semantic_search endpoint [" + apiName + "] should have failed because of missing role");
}
}
} catch (AssertionError ae) {
if (ae.getMessage().startsWith("call to")) {
// rethrow the fail() from the try section above
throw ae;
}
assertThat(ae.getMessage(), containsString("security_exception"));
assertThat(ae.getMessage(), containsString("returned [403 Forbidden]"));
assertThat(ae.getMessage(), containsString("is unauthorized for user [" + USERNAME + "]"));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.smoketest;

import com.carrotsearch.randomizedtesting.annotations.Name;

import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;

public class SemanticSearchReadPermissionsIT extends AbstractSemanticSearchPermissionsIT {

public SemanticSearchReadPermissionsIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
super(testCandidate);
}

@Override
protected String[] getCredentials() {
return new String[] { "read_index_no_ml", "read_index_no_ml_password" };
}
}

0 comments on commit 1d0a309

Please sign in to comment.