Skip to content

Commit

Permalink
feat: perform topic permission checks for KSQL service principal (#3261)
Browse files Browse the repository at this point in the history
Perform permissions checks on the KSQL service principal to display
better AUTH error messages on the CLI console, and thus providing
better feedback to the users about the authorization error.
  • Loading branch information
spena committed Aug 27, 2019
1 parent 6a50fca commit ba1f613
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.metastore.model.DataSource;
import io.confluent.ksql.parser.tree.CreateAsSelect;
import io.confluent.ksql.parser.tree.CreateSource;
import io.confluent.ksql.parser.tree.InsertInto;
import io.confluent.ksql.parser.tree.PrintTopic;
import io.confluent.ksql.parser.tree.Query;
Expand Down Expand Up @@ -51,6 +52,8 @@ public void checkAuthorization(
validateCreateAsSelect(serviceContext, metaStore, (CreateAsSelect)statement);
} else if (statement instanceof PrintTopic) {
validatePrintTopic(serviceContext, (PrintTopic)statement);
} else if (statement instanceof CreateSource) {
validateCreateSource(serviceContext, (CreateSource)statement);
}
}

Expand Down Expand Up @@ -111,6 +114,14 @@ private void validatePrintTopic(
checkAccess(serviceContext, printTopic.getTopic().toString(), AclOperation.READ);
}

private void validateCreateSource(
final ServiceContext serviceContext,
final CreateSource createSource
) {
final String sourceTopic = createSource.getProperties().getKafkaTopic();
checkAccess(serviceContext, sourceTopic, AclOperation.READ);
}

private String getSourceTopicName(final MetaStore metaStore, final String streamOrTable) {
final DataSource<?> dataSource = metaStore.getSource(streamOrTable);
if (dataSource == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
import org.mockito.junit.MockitoJUnitRunner;

@RunWith(MockitoJUnitRunner.class)
public class KsqlAuthorizationImplTest {
public class KsqlAuthorizationValidatorImplTest {

private static final LogicalSchema SCHEMA = LogicalSchema.of(SchemaBuilder
.struct()
Expand Down Expand Up @@ -376,6 +376,39 @@ public void shouldThrowWhenThrowPrintTopicWithoutReadPermissionsDenied() {
authorizationValidator.checkAuthorization(serviceContext, metaStore, statement);
}

@Test
public void shouldCreateSourceWithReadPermissionsAllowed() {
// Given:
givenTopicPermissions(TOPIC_1, Collections.singleton(AclOperation.READ));
final Statement statement = givenStatement(String.format(
"CREATE STREAM s1 WITH (kafka_topic='%s', value_format='JSON');", TOPIC_NAME_1)
);

// When:
authorizationValidator.checkAuthorization(serviceContext, metaStore, statement);

// Then:
// Above command should not throw any exception
}

@Test
public void shouldThrowWhenCreateSourceWithoutReadPermissionsDenied() {
// Given:
givenTopicPermissions(TOPIC_1, Collections.singleton(AclOperation.WRITE));
final Statement statement = givenStatement(String.format(
"CREATE STREAM s1 WITH (kafka_topic='%s', value_format='JSON');", TOPIC_NAME_1)
);

// Then:
expectedException.expect(KsqlTopicAuthorizationException.class);
expectedException.expectMessage(String.format(
"Authorization denied to Read on topic(s): [%s]", TOPIC_1.name()
));

// When:
authorizationValidator.checkAuthorization(serviceContext, metaStore, statement);
}

@Test
public void shouldThrowExceptionWhenTopicClientFails() {
// Given:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
package io.confluent.ksql.rest.server.computation;

import io.confluent.ksql.KsqlExecutionContext;
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.rest.entity.CommandStatus;
import io.confluent.ksql.rest.entity.CommandStatusEntity;
import io.confluent.ksql.rest.entity.KsqlEntity;
import io.confluent.ksql.rest.server.execution.StatementExecutor;
import io.confluent.ksql.security.KsqlAuthorizationValidator;
import io.confluent.ksql.services.ServiceContext;
import io.confluent.ksql.statement.ConfiguredStatement;
import io.confluent.ksql.statement.Injector;
Expand All @@ -40,15 +42,20 @@ public class DistributingExecutor implements StatementExecutor<Statement> {
private final CommandQueue commandQueue;
private final Duration distributedCmdResponseTimeout;
private final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory;
private final KsqlAuthorizationValidator authorizationValidator;

public DistributingExecutor(
final CommandQueue commandQueue,
final Duration distributedCmdResponseTimeout,
final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory) {
final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory,
final KsqlAuthorizationValidator authorizationValidator
) {
this.commandQueue = Objects.requireNonNull(commandQueue, "commandQueue");
this.distributedCmdResponseTimeout =
Objects.requireNonNull(distributedCmdResponseTimeout, "distributedCmdResponseTimeout");
this.injectorFactory = Objects.requireNonNull(injectorFactory, "injectorFactory");
this.authorizationValidator =
Objects.requireNonNull(authorizationValidator, "authorizationValidator");
}

@Override
Expand All @@ -61,6 +68,8 @@ public Optional<KsqlEntity> execute(
.apply(executionContext, serviceContext)
.inject(statement);

checkAuthorization(injected, serviceContext, executionContext);

try {
final QueuedCommandStatus queuedCommandStatus = commandQueue.enqueueCommand(injected);
final CommandStatus commandStatus = queuedCommandStatus
Expand All @@ -78,4 +87,27 @@ public Optional<KsqlEntity> execute(
statement.getStatementText()), e);
}
}

private void checkAuthorization(
final ConfiguredStatement<?> configured,
final ServiceContext userServiceContext,
final KsqlExecutionContext serverExecutionContext
) {
final Statement statement = configured.getStatement();
final MetaStore metaStore = serverExecutionContext.getMetaStore();

// Check the User will be permitted to execute this statement
authorizationValidator.checkAuthorization(userServiceContext, metaStore, statement);

try {
// Check the KSQL service principal will be permitted too
authorizationValidator.checkAuthorization(
serverExecutionContext.getServiceContext(),
metaStore,
statement
);
} catch (final Exception e) {
throw new KsqlServerException("The KSQL server is not permitted to execute the command", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,16 @@ public KsqlResource(
CustomValidators.VALIDATOR_MAP,
injectorFactory,
ksqlEngine::createSandbox,
ksqlConfig,
authorizationValidator);
ksqlConfig
);
this.handler = new RequestHandler(
CustomExecutors.EXECUTOR_MAP,
new DistributingExecutor(
commandQueue,
distributedCmdResponseTimeout,
injectorFactory),
injectorFactory,
authorizationValidator
),
ksqlEngine,
ksqlConfig,
new DefaultCommandQueueSync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import io.confluent.ksql.parser.tree.RunScript;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.rest.util.QueryCapacityUtil;
import io.confluent.ksql.security.KsqlAuthorizationValidator;
import io.confluent.ksql.services.ServiceContext;
import io.confluent.ksql.statement.ConfiguredStatement;
import io.confluent.ksql.statement.Injector;
Expand Down Expand Up @@ -55,7 +54,6 @@ public class RequestValidator {
private final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory;
private final Function<ServiceContext, KsqlExecutionContext> snapshotSupplier;
private final KsqlConfig ksqlConfig;
private final KsqlAuthorizationValidator authorizationValidator;

/**
* @param customValidators a map describing how to validate each statement of type
Expand All @@ -69,14 +67,12 @@ public RequestValidator(
final Map<Class<? extends Statement>, StatementValidator<?>> customValidators,
final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory,
final Function<ServiceContext, KsqlExecutionContext> snapshotSupplier,
final KsqlConfig ksqlConfig,
final KsqlAuthorizationValidator authorizationValidator
final KsqlConfig ksqlConfig
) {
this.customValidators = requireNonNull(customValidators, "customValidators");
this.injectorFactory = requireNonNull(injectorFactory, "injectorFactory");
this.snapshotSupplier = requireNonNull(snapshotSupplier, "snapshotSupplier");
this.ksqlConfig = requireNonNull(ksqlConfig, "ksqlConfig");
this.authorizationValidator = authorizationValidator;
}

/**
Expand Down Expand Up @@ -143,13 +139,6 @@ private <T extends Statement> int validate(
customValidator.validate(configured, executionContext, serviceContext);
} else if (KsqlEngine.isExecutableStatement(configured.getStatement())) {
final ConfiguredStatement<?> statementInjected = injector.inject(configured);

authorizationValidator.checkAuthorization(
serviceContext,
executionContext.getMetaStore(),
statementInjected.getStatement()
);

executionContext.execute(serviceContext, statementInjected);
} else {
throw new KsqlStatementException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.KsqlExecutionContext;
import io.confluent.ksql.exception.KsqlTopicAuthorizationException;
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.parser.KsqlParser.PreparedStatement;
import io.confluent.ksql.parser.tree.ListProperties;
import io.confluent.ksql.parser.tree.Statement;
Expand All @@ -32,6 +38,7 @@
import io.confluent.ksql.rest.entity.CommandStatusEntity;
import io.confluent.ksql.rest.server.computation.CommandId.Action;
import io.confluent.ksql.rest.server.computation.CommandId.Type;
import io.confluent.ksql.security.KsqlAuthorizationValidator;
import io.confluent.ksql.services.ServiceContext;
import io.confluent.ksql.statement.ConfiguredStatement;
import io.confluent.ksql.statement.Injector;
Expand Down Expand Up @@ -73,6 +80,9 @@ public class DistributingExecutorTest {
@Mock ServiceContext serviceContext;
@Mock Injector schemaInjector;
@Mock Injector topicInjector;
@Mock KsqlAuthorizationValidator authorizationValidator;
@Mock KsqlExecutionContext executionContext;
@Mock MetaStore metaStore;

private DistributingExecutor distributor;
private AtomicLong scnCounter;
Expand All @@ -86,17 +96,21 @@ public void setUp() throws InterruptedException {
when(status.tryWaitForFinalStatus(any())).thenReturn(SUCCESS_STATUS);
when(status.getCommandId()).thenReturn(CS_COMMAND);
when(status.getCommandSequenceNumber()).thenAnswer(inv -> scnCounter.incrementAndGet());
when(executionContext.getMetaStore()).thenReturn(metaStore);
when(executionContext.getServiceContext()).thenReturn(serviceContext);

distributor = new DistributingExecutor(
queue,
DURATION_10_MS,
(ec, sc) -> InjectorChain.of(schemaInjector, topicInjector));
(ec, sc) -> InjectorChain.of(schemaInjector, topicInjector),
authorizationValidator
);
}

@Test
public void shouldEnqueueSuccessfulCommand() throws InterruptedException {
// When:
distributor.execute(EMPTY_STATEMENT, null, serviceContext);
distributor.execute(EMPTY_STATEMENT, executionContext, serviceContext);

// Then:
verify(queue, times(1)).enqueueCommand(eq(EMPTY_STATEMENT));
Expand All @@ -105,7 +119,7 @@ public void shouldEnqueueSuccessfulCommand() throws InterruptedException {
@Test
public void shouldInferSchemas() {
// When:
distributor.execute(EMPTY_STATEMENT, null, serviceContext);
distributor.execute(EMPTY_STATEMENT, executionContext, serviceContext);

// Then:
verify(schemaInjector, times(1)).inject(eq(EMPTY_STATEMENT));
Expand All @@ -117,7 +131,7 @@ public void shouldReturnCommandStatus() {
final CommandStatusEntity commandStatusEntity =
(CommandStatusEntity) distributor.execute(
EMPTY_STATEMENT,
null,
executionContext,
serviceContext
)
.orElseThrow(null);
Expand Down Expand Up @@ -150,7 +164,7 @@ public void shouldThrowExceptionOnFailureToEnqueue() {
expectedException.expectCause(is(cause));

// When:
distributor.execute(configured, null, serviceContext);
distributor.execute(configured, executionContext, serviceContext);
}

@Test
Expand All @@ -167,7 +181,43 @@ public void shouldThrowFailureIfCannotInferSchema() {
expectedException.expectMessage("Could not infer!");

// When:
distributor.execute(configured, null, serviceContext);
distributor.execute(configured, executionContext, serviceContext);
}

@Test
public void shouldThrowExceptionIfUserServiceContextIsDeniedAuthorization() {
// Given:
final ServiceContext userServiceContext = mock(ServiceContext.class);
final PreparedStatement<Statement> preparedStatement =
PreparedStatement.of("", new ListProperties(Optional.empty()));
final ConfiguredStatement<Statement> configured =
ConfiguredStatement.of(preparedStatement, ImmutableMap.of(), KSQL_CONFIG);
doThrow(KsqlTopicAuthorizationException.class).when(authorizationValidator)
.checkAuthorization(eq(userServiceContext), any(), eq(configured.getStatement()));

// Expect:
expectedException.expect(KsqlTopicAuthorizationException.class);

// When:
distributor.execute(configured, executionContext, userServiceContext);
}

@Test
public void shouldThrowServerExceptionIfServerServiceContextIsDeniedAuthorization() {
// Given:
final ServiceContext userServiceContext = mock(ServiceContext.class);
final PreparedStatement<Statement> preparedStatement =
PreparedStatement.of("", new ListProperties(Optional.empty()));
final ConfiguredStatement<Statement> configured =
ConfiguredStatement.of(preparedStatement, ImmutableMap.of(), KSQL_CONFIG);
doThrow(KsqlTopicAuthorizationException.class).when(authorizationValidator)
.checkAuthorization(eq(serviceContext), any(), eq(configured.getStatement()));

// Expect:
expectedException.expect(KsqlServerException.class);
expectedException.expectCause(is(instanceOf(KsqlTopicAuthorizationException.class)));

// When:
distributor.execute(configured, executionContext, userServiceContext);
}
}

0 comments on commit ba1f613

Please sign in to comment.