diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java index 91244a1ad36b9..6bbe5ac69f05e 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java @@ -13,6 +13,9 @@ import org.elasticsearch.action.delete.DeleteRequest; import org.elasticsearch.action.delete.DeleteResponse; import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.action.get.MultiGetItemResponse; +import org.elasticsearch.action.get.MultiGetRequest; +import org.elasticsearch.action.get.MultiGetResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.action.search.MultiSearchResponse; @@ -49,7 +52,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -113,33 +115,54 @@ public void getRoleDescriptors(Set names, final ActionListener { - QueryBuilder query; - if (names == null || names.isEmpty()) { - query = QueryBuilders.termQuery(RoleDescriptor.Fields.TYPE.getPreferredName(), ROLE_TYPE); - } else { - final String[] roleNames = names.stream().map(NativeRolesStore::getIdForUser).toArray(String[]::new); - query = QueryBuilders.boolQuery().filter(QueryBuilders.idsQuery(ROLE_DOC_TYPE).addIds(roleNames)); - } + QueryBuilder query = QueryBuilders.termQuery(RoleDescriptor.Fields.TYPE.getPreferredName(), ROLE_TYPE); final Supplier supplier = client.threadPool().getThreadContext().newRestorableContext(false); try (ThreadContext.StoredContext ignore = stashWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN)) { SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME) - .setScroll(DEFAULT_KEEPALIVE_SETTING.get(settings)) - .setQuery(query) - .setSize(1000) - .setFetchSource(true) - .request(); + .setScroll(DEFAULT_KEEPALIVE_SETTING.get(settings)) + .setQuery(query) + .setSize(1000) + .setFetchSource(true) + .request(); request.indicesOptions().ignoreUnavailable(); - final ActionListener> descriptorsListener = ActionListener.wrap( - roleDescriptors -> listener.onResponse(RoleRetrievalResult.success(new HashSet<>(roleDescriptors))), - e -> listener.onResponse(RoleRetrievalResult.failure(e))); - ScrollHelper.fetchAllByEntity(client, request, new ContextPreservingActionListener<>(supplier, descriptorsListener), - (hit) -> transformRole(hit.getId(), hit.getSourceRef(), logger, licenseState)); + ScrollHelper.fetchAllByEntity(client, request, new ContextPreservingActionListener<>(supplier, + ActionListener.wrap(roles -> listener.onResponse(RoleRetrievalResult.success(new HashSet<>(roles))), + e -> listener.onResponse(RoleRetrievalResult.failure(e)))), + (hit) -> transformRole(hit.getId(), hit.getSourceRef(), logger, licenseState)); } }); + } else if (names.size() == 1) { + getRoleDescriptor(Objects.requireNonNull(names.iterator().next()), listener); + } else { + securityIndex.checkIndexVersionThenExecute(listener::onFailure, () -> { + final String[] roleIds = names.stream().map(NativeRolesStore::getIdForRole).toArray(String[]::new); + MultiGetRequest multiGetRequest = client.prepareMultiGet().add(SECURITY_INDEX_NAME, ROLE_DOC_TYPE, roleIds).request(); + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, multiGetRequest, + ActionListener.wrap(mGetResponse -> { + final MultiGetItemResponse[] responses = mGetResponse.getResponses(); + Set descriptors = new HashSet<>(); + for (int i = 0; i < responses.length; i++) { + MultiGetItemResponse item = responses[i]; + if (item.isFailed()) { + final Exception failure = item.getFailure().getFailure(); + for (int j = i + 1; j < responses.length; j++) { + item = responses[j]; + if (item.isFailed()) { + failure.addSuppressed(failure); + } + } + listener.onResponse(RoleRetrievalResult.failure(failure)); + return; + } else if (item.getResponse().isExists()) { + descriptors.add(transformRole(item.getResponse())); + } + } + listener.onResponse(RoleRetrievalResult.success(descriptors)); + }, + e -> listener.onResponse(RoleRetrievalResult.failure(e))), client::multiGet); + }); } } @@ -152,7 +175,7 @@ public void deleteRole(final DeleteRoleRequest deleteRoleRequest, final ActionLi } else { securityIndex.checkIndexVersionThenExecute(listener::onFailure, () -> { DeleteRequest request = client.prepareDelete(SecurityIndexManager.SECURITY_INDEX_NAME, - ROLE_DOC_TYPE, getIdForUser(deleteRoleRequest.name())).request(); + ROLE_DOC_TYPE, getIdForRole(deleteRoleRequest.name())).request(); request.setRefreshPolicy(deleteRoleRequest.getRefreshPolicy()); executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request, new ActionListener() { @@ -192,7 +215,7 @@ void innerPutRole(final PutRoleRequest request, final RoleDescriptor role, final listener.onFailure(e); return; } - final IndexRequest indexRequest = client.prepareIndex(SECURITY_INDEX_NAME, ROLE_DOC_TYPE, getIdForUser(role.getName())) + final IndexRequest indexRequest = client.prepareIndex(SECURITY_INDEX_NAME, ROLE_DOC_TYPE, getIdForRole(role.getName())) .setSource(xContentBuilder) .setRefreshPolicy(request.getRefreshPolicy()) .request(); @@ -308,7 +331,7 @@ private void executeGetRoleRequest(String role, ActionListener list securityIndex.checkIndexVersionThenExecute(listener::onFailure, () -> executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, client.prepareGet(SECURITY_INDEX_NAME, - ROLE_DOC_TYPE, getIdForUser(role)).request(), + ROLE_DOC_TYPE, getIdForRole(role)).request(), listener, client::get)); } @@ -388,7 +411,7 @@ public static void addSettings(List> settings) { /** * Gets the document's id field for the given role name. */ - private static String getIdForUser(final String roleName) { + private static String getIdForRole(final String roleName) { return ROLE_TYPE + "-" + roleName; } }