Skip to content

Commit

Permalink
Short circuit authorization for child actions (#77221) (#78295)
Browse files Browse the repository at this point in the history
This commit detects a specific case when a child action (e.g. a shard
level action, or a phased action) acts on the same indices (or a
subset of the indices) or that parent request, and we can retain the
original authorization result.

The optimization is only effective for the invocation of the child
action on the same node as the parent - if the transport action needs
to be executed on a remote node then that authorization will not be
optimized and will perform the full check as existed before this
change.

This change is primarily benefitial for actions where a single parent
action on a coordinating node triggers the execution of multiple
children (e.g. a child action per shard) as it allows the
coordinating node to trigger those action and allow the load
to be passed to the remote nodes as quickly as possible rather than
having authorization on the coordinating node become a bottleneck.

Co-authored-by: Tim Vernum <tim.vernum@elastic.co>
  • Loading branch information
ywangd and tvernum committed Sep 25, 2021
1 parent cee5dc8 commit c05c023
Show file tree
Hide file tree
Showing 11 changed files with 511 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@
import static org.hamcrest.Matchers.both;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.either;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.everyItem;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

Expand Down Expand Up @@ -89,7 +91,7 @@ public void testThatBulkProcessorCountIsCorrect() throws Exception {

assertThat(listener.beforeCounts.get(), equalTo(1));
assertThat(listener.afterCounts.get(), equalTo(1));
assertThat(listener.bulkFailures.size(), equalTo(0));
assertThat(listener.bulkFailures, empty());
assertResponseItems(listener.bulkItems, numDocs);
assertMultiGetResponse(highLevelClient().mget(multiGetRequest, RequestOptions.DEFAULT), numDocs);
}
Expand All @@ -115,7 +117,7 @@ public void testBulkProcessorFlush() throws Exception {

assertThat(listener.beforeCounts.get(), equalTo(1));
assertThat(listener.afterCounts.get(), equalTo(1));
assertThat(listener.bulkFailures.size(), equalTo(0));
assertThat(listener.bulkFailures, empty());
assertResponseItems(listener.bulkItems, numDocs);
assertMultiGetResponse(highLevelClient().mget(multiGetRequest, RequestOptions.DEFAULT), numDocs);
}
Expand Down Expand Up @@ -147,16 +149,16 @@ public void testBulkProcessorConcurrentRequests() throws Exception {

assertThat(listener.beforeCounts.get(), equalTo(expectedBulkActions));
assertThat(listener.afterCounts.get(), equalTo(expectedBulkActions));
assertThat(listener.bulkFailures.size(), equalTo(0));
assertThat(listener.bulkItems.size(), equalTo(numDocs - numDocs % bulkActions));
assertThat(listener.bulkFailures, empty());
assertThat(listener.bulkItems, hasSize(numDocs - numDocs % bulkActions));
}

closeLatch.await();

assertThat(listener.beforeCounts.get(), equalTo(totalExpectedBulkActions));
assertThat(listener.afterCounts.get(), equalTo(totalExpectedBulkActions));
assertThat(listener.bulkFailures.size(), equalTo(0));
assertThat(listener.bulkItems.size(), equalTo(numDocs));
assertThat(listener.bulkFailures, empty());
assertThat(listener.bulkItems, hasSize(numDocs));

Set<String> ids = new HashSet<>();
for (BulkItemResponse bulkItemResponse : listener.bulkItems) {
Expand Down Expand Up @@ -198,7 +200,7 @@ public void testBulkProcessorWaitOnClose() throws Exception {
for (Throwable bulkFailure : listener.bulkFailures) {
logger.error("bulk failure", bulkFailure);
}
assertThat(listener.bulkFailures.size(), equalTo(0));
assertThat(listener.bulkFailures, empty());
assertResponseItems(listener.bulkItems, numDocs);
assertMultiGetResponse(highLevelClient().mget(multiGetRequest, RequestOptions.DEFAULT), numDocs);
}
Expand Down Expand Up @@ -255,8 +257,8 @@ public void testBulkProcessorConcurrentRequestsReadOnlyIndex() throws Exception

assertThat(listener.beforeCounts.get(), equalTo(totalExpectedBulkActions));
assertThat(listener.afterCounts.get(), equalTo(totalExpectedBulkActions));
assertThat(listener.bulkFailures.size(), equalTo(0));
assertThat(listener.bulkItems.size(), equalTo(testDocs + testReadOnlyDocs));
assertThat(listener.bulkFailures, empty());
assertThat(listener.bulkItems, hasSize(testDocs + testReadOnlyDocs));

Set<String> ids = new HashSet<>();
Set<String> readOnlyIds = new HashSet<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public void testAuthorizeRunAs() {
Authentication authentication =
new Authentication(new User("joe", new String[]{"custom_superuser"}, new User("bar", "not_superuser")),
new RealmRef("test", "test", "node"), new RealmRef("test", "test", "node"));
RequestInfo info = new RequestInfo(authentication, request, action);
RequestInfo info = new RequestInfo(authentication, request, action, null);
PlainActionFuture<AuthorizationInfo> future = new PlainActionFuture<>();
engine.resolveAuthorizationInfo(info, future);
AuthorizationInfo authzInfo = future.actionGet();
Expand All @@ -72,7 +72,7 @@ public void testAuthorizeRunAs() {
Authentication authentication =
new Authentication(new User("joe", new String[]{"not_superuser"}, new User("bar", "custom_superuser")),
new RealmRef("test", "test", "node"), new RealmRef("test", "test", "node"));
RequestInfo info = new RequestInfo(authentication, request, action);
RequestInfo info = new RequestInfo(authentication, request, action, null);
PlainActionFuture<AuthorizationInfo> future = new PlainActionFuture<>();
engine.resolveAuthorizationInfo(info, future);
AuthorizationInfo authzInfo = future.actionGet();
Expand Down Expand Up @@ -104,7 +104,7 @@ public void testAuthorizeClusterAction() {
{
RequestInfo unauthReqInfo =
new RequestInfo(new Authentication(new User("joe", "not_superuser"), new RealmRef("test", "test", "node"), null),
requestInfo.getRequest(), requestInfo.getAction());
requestInfo.getRequest(), requestInfo.getAction(), null);
PlainActionFuture<AuthorizationInfo> future = new PlainActionFuture<>();
engine.resolveAuthorizationInfo(unauthReqInfo, future);
AuthorizationInfo authzInfo = future.actionGet();
Expand All @@ -129,7 +129,7 @@ public void testAuthorizeIndexAction() {
{
RequestInfo requestInfo =
new RequestInfo(new Authentication(new User("joe", "custom_superuser"), new RealmRef("test", "test", "node"), null),
new SearchRequest(), "indices:data/read/search");
new SearchRequest(), "indices:data/read/search", null);
PlainActionFuture<AuthorizationInfo> future = new PlainActionFuture<>();
engine.resolveAuthorizationInfo(requestInfo, future);
AuthorizationInfo authzInfo = future.actionGet();
Expand All @@ -150,7 +150,7 @@ public void testAuthorizeIndexAction() {
{
RequestInfo requestInfo =
new RequestInfo(new Authentication(new User("joe", "not_superuser"), new RealmRef("test", "test", "node"), null),
new SearchRequest(), "indices:data/read/search");
new SearchRequest(), "indices:data/read/search", null);
PlainActionFuture<AuthorizationInfo> future = new PlainActionFuture<>();
engine.resolveAuthorizationInfo(requestInfo, future);
AuthorizationInfo authzInfo = future.actionGet();
Expand All @@ -172,6 +172,6 @@ private RequestInfo getRequestInfo() {
final TransportRequest request = new TransportRequest() {};
final Authentication authentication =
new Authentication(new User("joe", "custom_superuser"), new RealmRef("test", "test", "node"), null);
return new RequestInfo(authentication, request, action);
return new RequestInfo(authentication, request, action, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
Expand Down Expand Up @@ -247,11 +248,19 @@ final class RequestInfo {
private final Authentication authentication;
private final TransportRequest request;
private final String action;

public RequestInfo(Authentication authentication, TransportRequest request, String action) {
this.authentication = authentication;
this.request = request;
this.action = action;
@Nullable
private final AuthorizationContext originatingAuthorizationContext;

public RequestInfo(
Authentication authentication,
TransportRequest request,
String action,
AuthorizationContext originatingContext
) {
this.authentication = Objects.requireNonNull(authentication);
this.request = Objects.requireNonNull(request);
this.action = Objects.requireNonNull(action);
this.originatingAuthorizationContext = originatingContext;
}

public String getAction() {
Expand All @@ -265,6 +274,27 @@ public Authentication getAuthentication() {
public TransportRequest getRequest() {
return request;
}

@Nullable
public AuthorizationContext getOriginatingAuthorizationContext() {
return originatingAuthorizationContext;
}

@Override
public String toString() {
return getClass().getSimpleName()
+ '{'
+ "authentication=["
+ authentication
+ "], request=["
+ request
+ "], action=["
+ action
+ ']'
+ ", parent=["
+ originatingAuthorizationContext
+ "]}";
}
}

/**
Expand Down Expand Up @@ -354,6 +384,31 @@ public IndicesAccessControl getIndicesAccessControl() {
}
}


final class AuthorizationContext {
private final String action;
private final AuthorizationInfo authorizationInfo;
private final IndicesAccessControl indicesAccessControl;

public AuthorizationContext(String action, AuthorizationInfo authorizationInfo, IndicesAccessControl accessControl) {
this.action = action;
this.authorizationInfo = authorizationInfo;
this.indicesAccessControl = accessControl;
}

public String getAction() {
return action;
}

public AuthorizationInfo getAuthorizationInfo() {
return authorizationInfo;
}

public IndicesAccessControl getIndicesAccessControl() {
return indicesAccessControl;
}
}

@FunctionalInterface
interface AsyncSupplier<V> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ public RunAsPermission runAs() {
throw new UnsupportedOperationException("cannot retrieve run_as permission on limited role");
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (super.equals(o) == false) {
return false;
}
LimitedRole that = (LimitedRole) o;
return this.limitedBy.equals(that.limitedBy);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), limitedBy);
}

@Override
public IndicesAccessControl authorize(String action, Set<String> requestedIndicesOrAliases,
Map<String, IndexAbstraction> aliasAndIndexLookup,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,29 @@ public IndicesAccessControl authorize(String action, Set<String> requestedIndice
return new IndicesAccessControl(granted, indexPermissions);
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
Role that = (Role) o;
return Arrays.equals(this.names, that.names)
&& this.cluster.equals(that.cluster)
&& this.indices.equals(that.indices)
&& this.application.equals(that.application)
&& this.runAs.equals(that.runAs);
}

@Override
public int hashCode() {
int result = Objects.hash(cluster, indices, application, runAs);
result = 31 * result + Arrays.hashCode(names);
return result;
}

public static class Builder {

private final String[] names;
Expand Down

0 comments on commit c05c023

Please sign in to comment.