Skip to content

Commit

Permalink
Ensure kNN search respects authorization (#79693)
Browse files Browse the repository at this point in the history
This PR ensures the `_knn_search` endpoint handles both field-level security (FLS) and document-level security (DLS):
* Updates `FieldSubsetReader` to handle FLS for the vectors format
* Adds tests to check both DLS and FLS work

Relates to #78473.
  • Loading branch information
jtibshirani committed Oct 26, 2021
1 parent 9326e5a commit 07428a1
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FilterIterator;
import org.apache.lucene.util.automaton.CharacterRunAutomaton;
Expand Down Expand Up @@ -275,6 +278,16 @@ public NumericDocValues getNormValues(String field) throws IOException {
return hasField(field) ? super.getNormValues(field) : null;
}

/**
 * Returns the vector values for {@code field}, but only when the field passes this
 * reader's {@code hasField} check.
 *
 * @param field name of the vector field being requested
 * @return whatever the superclass returns for the field, or {@code null} when
 *         {@code hasField(field)} is false
 * @throws IOException propagated from the superclass call
 */
@Override
public VectorValues getVectorValues(String field) throws IOException {
    if (hasField(field) == false) {
        // Field is filtered out: answer as if it carried no vectors at all.
        return null;
    }
    return super.getVectorValues(field);
}

/**
 * Runs a nearest-neighbour search on {@code field}, but only when the field passes this
 * reader's {@code hasField} check.
 *
 * @param field      name of the vector field to search
 * @param target     query vector
 * @param k          number of neighbours requested
 * @param acceptDocs passed through unchanged to the superclass
 * @return the superclass result, or {@code null} when {@code hasField(field)} is false
 * @throws IOException propagated from the superclass call
 */
@Override
public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs) throws IOException {
    if (hasField(field) == false) {
        // Field is filtered out: do not delegate, so no neighbours can be reported for it.
        return null;
    }
    return super.searchNearestVectors(field, target, k, acceptDocs);
}

// we share core cache keys (for e.g. fielddata)

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField;
Expand Down Expand Up @@ -42,6 +43,8 @@
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.index.TermsEnum.SeekStatus;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TestUtil;
Expand All @@ -58,11 +61,11 @@
import org.elasticsearch.common.lucene.index.SequentialStoredFieldsLeafReader;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.index.mapper.FieldNamesFieldMapper;
import org.elasticsearch.index.mapper.SourceFieldMapper;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.security.authz.permission.FieldPermissions;
import org.elasticsearch.xpack.core.security.authz.permission.FieldPermissionsDefinition;
import org.elasticsearch.xpack.core.security.support.Automatons;
Expand Down Expand Up @@ -180,6 +183,38 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
IOUtils.close(ir, iw, dir);
}

/**
 * test filtering two kNN vector fields
 */
public void testKnnVectors() throws Exception {
    Directory directory = newDirectory();
    IndexWriter writer = new IndexWriter(directory, new IndexWriterConfig(null));

    // A single document carrying one vector per field.
    Document document = new Document();
    document.add(new KnnVectorField("fieldA", new float[] {0.1f, 0.2f, 0.3f}));
    document.add(new KnnVectorField("fieldB", new float[] {3.0f, 2.0f, 1.0f}));
    writer.addDocument(document);

    // Wrap the reader with an automaton that accepts only "fieldA".
    DirectoryReader unfiltered = DirectoryReader.open(writer);
    CharacterRunAutomaton onlyFieldA = new CharacterRunAutomaton(Automata.makeString("fieldA"));
    DirectoryReader filtered = FieldSubsetReader.wrap(unfiltered, onlyFieldA);
    LeafReader leaf = filtered.leaves().get(0).reader();

    // fieldA is accepted by the automaton, so it must behave like an unfiltered field
    VectorValues visibleVectors = leaf.getVectorValues("fieldA");
    assertEquals(3, visibleVectors.dimension());
    assertEquals(1, visibleVectors.size());
    assertEquals(0, visibleVectors.nextDoc());
    assertNotNull(visibleVectors.binaryValue());

    float[] target = new float[] {1.0f, 1.0f, 1.0f};
    TopDocs neighbours = leaf.searchNearestVectors("fieldA", target, 5, null);
    assertNotNull(neighbours);
    assertEquals(1, neighbours.scoreDocs.length);

    // fieldB is rejected by the automaton, so both vector entry points must return null
    assertNull(leaf.getVectorValues("fieldB"));
    assertNull(leaf.searchNearestVectors("fieldB", target, 5, null));

    TestUtil.checkReader(filtered);
    IOUtils.close(filtered, writer, directory);
}

/**
* test filtering two stored fields (string)
*/
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugin/security/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies {

testImplementation project(path: xpackModule('monitoring'))
testImplementation project(path: xpackModule('spatial'))
testImplementation project(path: xpackModule('vectors'))
testImplementation project(path: ':modules:legacy-geo')
testImplementation project(path: ':modules:percolator')
testImplementation project(path: xpackModule('sql:sql-action'))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.analysis.common.CommonAnalysisPlugin;
import org.elasticsearch.client.Requests;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.geo.ShapeRelation;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.FuzzyQueryBuilder;
import org.elasticsearch.index.query.InnerHitBuilder;
Expand All @@ -45,6 +44,7 @@
import org.elasticsearch.percolator.PercolatorPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.bucket.global.Global;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
Expand All @@ -65,10 +65,15 @@
import org.elasticsearch.test.InternalSettingsPlugin;
import org.elasticsearch.test.SecurityIntegTestCase;
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.security.LocalStateSecurity;
import org.elasticsearch.xpack.spatial.SpatialPlugin;
import org.elasticsearch.xpack.spatial.index.query.ShapeQueryBuilder;
import org.elasticsearch.xpack.vectors.DenseVectorPlugin;
import org.elasticsearch.xpack.vectors.query.KnnVectorQueryBuilder;

import java.util.Arrays;
import java.util.Collection;
Expand All @@ -79,7 +84,6 @@

import static java.util.stream.Collectors.toList;
import static org.elasticsearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE;
import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery;
import static org.elasticsearch.index.query.QueryBuilders.termQuery;
import static org.elasticsearch.integration.FieldLevelSecurityTests.openPointInTime;
Expand All @@ -90,6 +94,7 @@
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchHits;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.BASIC_AUTH_HEADER;
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
import static org.hamcrest.Matchers.equalTo;
Expand All @@ -105,7 +110,7 @@ public class DocumentLevelSecurityTests extends SecurityIntegTestCase {
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(LocalStateSecurity.class, CommonAnalysisPlugin.class, ParentJoinPlugin.class,
InternalSettingsPlugin.class, SpatialPlugin.class, PercolatorPlugin.class);
InternalSettingsPlugin.class, DenseVectorPlugin.class, SpatialPlugin.class, PercolatorPlugin.class);
}

@Override
Expand Down Expand Up @@ -853,6 +858,77 @@ public void testMTVApi() throws Exception {
assertThat(response.getResponses()[0].getResponse().isExists(), is(false));
}

/**
 * Checks that a kNN vector query only returns the documents each user is allowed to see.
 * Ten documents are indexed into a one-shard index: five with {@code field1: value1} and
 * five with {@code field2: value2}, all carrying a {@code vector} field.
 */
public void testKnnSearch() throws Exception {
    Settings oneShard = Settings.builder()
        .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
        .build();
    XContentBuilder mapping = XContentFactory.jsonBuilder()
        .startObject()
            .startObject("properties")
                .startObject("vector")
                    .field("type", "dense_vector")
                    .field("dims", 3)
                    .field("index", true)
                    .field("similarity", "l2_norm")
                .endObject()
            .endObject()
        .endObject();
    assertAcked(client().admin().indices().prepareCreate("test").setSettings(oneShard).setMapping(mapping));

    // For every coordinate value, index one "field1" document and one "field2" document.
    for (int coordinate = 0; coordinate < 5; coordinate++) {
        client().prepareIndex("test")
            .setSource("field1", "value1", "vector", new float[] {coordinate, coordinate, coordinate})
            .get();
        client().prepareIndex("test")
            .setSource("field2", "value2", "vector", new float[] {coordinate, coordinate, coordinate})
            .get();
    }

    client().admin().indices().prepareRefresh("test").get();

    // Since there's no kNN search action at the transport layer, we just emulate
    // how the action works (it builds a kNN query under the hood)
    float[] origin = new float[] {0.0f, 0.0f, 0.0f};
    KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder("vector", origin, 50);

    // user1 should only be able to see docs with field1: value1
    SearchResponse user1Response = client()
        .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD)))
        .prepareSearch("test")
        .setQuery(knnQuery)
        .addFetchField("field1")
        .setSize(10)
        .get();
    assertEquals(5, user1Response.getHits().getTotalHits().value);
    SearchHit[] user1Hits = user1Response.getHits().getHits();
    assertEquals(5, user1Hits.length);
    for (SearchHit hit : user1Hits) {
        assertNotNull(hit.field("field1"));
    }

    // user2 should only be able to see docs with field2: value2
    SearchResponse user2Response = client()
        .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user2", USERS_PASSWD)))
        .prepareSearch("test")
        .setQuery(knnQuery)
        .addFetchField("field2")
        .setSize(10)
        .get();
    assertEquals(5, user2Response.getHits().getTotalHits().value);
    SearchHit[] user2Hits = user2Response.getHits().getHits();
    assertEquals(5, user2Hits.length);
    for (SearchHit hit : user2Hits) {
        assertNotNull(hit.field("field2"));
    }

    // user3 can see all indexed docs
    SearchResponse user3Response = client()
        .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user3", USERS_PASSWD)))
        .prepareSearch("test")
        .setQuery(knnQuery)
        .setSize(10)
        .get();
    assertEquals(10, user3Response.getHits().getTotalHits().value);
    assertEquals(10, user3Response.getHits().getHits().length);
}

public void testGlobalAggregation() throws Exception {
assertAcked(client().admin().indices().prepareCreate("test")
.setMapping("field1", "type=text", "field2", "type=text,fielddata=true", "field3", "type=text")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@
import org.elasticsearch.xpack.security.LocalStateSecurity;
import org.elasticsearch.xpack.spatial.SpatialPlugin;
import org.elasticsearch.xpack.spatial.index.query.ShapeQueryBuilder;
import org.elasticsearch.xpack.vectors.DenseVectorPlugin;
import org.elasticsearch.xpack.vectors.query.KnnVectorQueryBuilder;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
Expand Down Expand Up @@ -90,7 +93,7 @@ public class FieldLevelSecurityTests extends SecurityIntegTestCase {
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(LocalStateSecurity.class, CommonAnalysisPlugin.class, ParentJoinPlugin.class,
InternalSettingsPlugin.class, PercolatorPlugin.class, SpatialPlugin.class);
InternalSettingsPlugin.class, PercolatorPlugin.class, DenseVectorPlugin.class, SpatialPlugin.class);
}

@Override
Expand Down Expand Up @@ -140,7 +143,7 @@ protected String configRoles() {
" - names: '*'\n" +
" privileges: [ ALL ]\n" +
" field_security:\n" +
" grant: [ field1, join_field* ]\n" +
" grant: [ field1, join_field*, vector ]\n" +
"role3:\n" +
" cluster: [ all ]\n" +
" indices:\n" +
Expand Down Expand Up @@ -351,6 +354,54 @@ public void testQuery() {
assertHitCount(response, 0);
}

/**
 * Checks that a kNN vector query respects field level security on the {@code vector}
 * field: a user granted the field matches and can fetch it, while a user without the
 * grant neither matches through the kNN query nor sees the field in fetched results.
 */
public void testKnnSearch() throws IOException {
    XContentBuilder mapping = XContentFactory.jsonBuilder()
        .startObject()
            .startObject("properties")
                .startObject("vector")
                    .field("type", "dense_vector")
                    .field("dims", 3)
                    .field("index", true)
                    .field("similarity", "l2_norm")
                .endObject()
            .endObject()
        .endObject();
    assertAcked(client().admin().indices().prepareCreate("test").setMapping(mapping));

    // Index a single document and refresh immediately so it is searchable below.
    client().prepareIndex("test")
        .setSource("field1", "value1", "vector", new float[] {0.0f, 0.0f, 0.0f})
        .setRefreshPolicy(IMMEDIATE)
        .get();

    // Since there's no kNN search action at the transport layer, we just emulate
    // how the action works (it builds a kNN query under the hood)
    float[] origin = new float[] {0.0f, 0.0f, 0.0f};
    KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder("vector", origin, 10);

    // user1 has access to the vector field, so the query should match the document
    SearchResponse grantedResponse = client()
        .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD)))
        .prepareSearch("test")
        .setQuery(knnQuery)
        .addFetchField("vector")
        .get();
    assertHitCount(grantedResponse, 1);
    assertNotNull(grantedResponse.getHits().getAt(0).field("vector"));

    // user2 has no access to the vector field, so the query should not match the document
    SearchResponse deniedQueryResponse = client()
        .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user2", USERS_PASSWD)))
        .prepareSearch("test")
        .setQuery(knnQuery)
        .addFetchField("vector")
        .get();
    assertHitCount(deniedQueryResponse, 0);

    // check user2 cannot see the vector field, even when their search matches the document
    SearchResponse deniedFetchResponse = client()
        .filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user2", USERS_PASSWD)))
        .prepareSearch("test")
        .addFetchField("vector")
        .get();
    assertHitCount(deniedFetchResponse, 1);
    assertNull(deniedFetchResponse.getHits().getAt(0).field("vector"));
}

public void testPercolateQueryWithIndexedDocWithFLS() {
assertAcked(client().admin().indices().prepareCreate("query_index")
.setMapping("query", "type=percolator", "field2", "type=text")
Expand Down

0 comments on commit 07428a1

Please sign in to comment.