Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: perform topic permission checks for KSQL service principal #3261

Merged
merged 3 commits into from
Aug 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.metastore.model.DataSource;
import io.confluent.ksql.parser.tree.CreateAsSelect;
import io.confluent.ksql.parser.tree.CreateSource;
import io.confluent.ksql.parser.tree.InsertInto;
import io.confluent.ksql.parser.tree.PrintTopic;
import io.confluent.ksql.parser.tree.Query;
Expand Down Expand Up @@ -51,6 +52,8 @@ public void checkAuthorization(
validateCreateAsSelect(serviceContext, metaStore, (CreateAsSelect)statement);
} else if (statement instanceof PrintTopic) {
validatePrintTopic(serviceContext, (PrintTopic)statement);
} else if (statement instanceof CreateSource) {
validateCreateSource(serviceContext, (CreateSource)statement);
}
}

Expand Down Expand Up @@ -111,6 +114,14 @@ private void validatePrintTopic(
checkAccess(serviceContext, printTopic.getTopic().toString(), AclOperation.READ);
}

private void validateCreateSource(
final ServiceContext serviceContext,
final CreateSource createSource
) {
final String sourceTopic = createSource.getProperties().getKafkaTopic();
checkAccess(serviceContext, sourceTopic, AclOperation.READ);
}

private String getSourceTopicName(final MetaStore metaStore, final String streamOrTable) {
final DataSource<?> dataSource = metaStore.getSource(streamOrTable);
if (dataSource == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
import org.mockito.junit.MockitoJUnitRunner;

@RunWith(MockitoJUnitRunner.class)
public class KsqlAuthorizationImplTest {
public class KsqlAuthorizationValidatorImplTest {

private static final LogicalSchema SCHEMA = LogicalSchema.of(SchemaBuilder
.struct()
Expand Down Expand Up @@ -376,6 +376,39 @@ public void shouldThrowWhenThrowPrintTopicWithoutReadPermissionsDenied() {
authorizationValidator.checkAuthorization(serviceContext, metaStore, statement);
}

@Test
public void shouldCreateSourceWithReadPermissionsAllowed() {
// Given:
givenTopicPermissions(TOPIC_1, Collections.singleton(AclOperation.READ));
final Statement statement = givenStatement(String.format(
"CREATE STREAM s1 WITH (kafka_topic='%s', value_format='JSON');", TOPIC_NAME_1)
);

// When:
authorizationValidator.checkAuthorization(serviceContext, metaStore, statement);

// Then:
// Above command should not throw any exception
}

@Test
public void shouldThrowWhenCreateSourceWithoutReadPermissionsDenied() {
// Given:
givenTopicPermissions(TOPIC_1, Collections.singleton(AclOperation.WRITE));
final Statement statement = givenStatement(String.format(
"CREATE STREAM s1 WITH (kafka_topic='%s', value_format='JSON');", TOPIC_NAME_1)
);

// Then:
expectedException.expect(KsqlTopicAuthorizationException.class);
expectedException.expectMessage(String.format(
"Authorization denied to Read on topic(s): [%s]", TOPIC_1.name()
));

// When:
authorizationValidator.checkAuthorization(serviceContext, metaStore, statement);
}

@Test
public void shouldThrowExceptionWhenTopicClientFails() {
// Given:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
package io.confluent.ksql.rest.server.computation;

import io.confluent.ksql.KsqlExecutionContext;
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.rest.entity.CommandStatus;
import io.confluent.ksql.rest.entity.CommandStatusEntity;
import io.confluent.ksql.rest.entity.KsqlEntity;
import io.confluent.ksql.rest.server.execution.StatementExecutor;
import io.confluent.ksql.security.KsqlAuthorizationValidator;
import io.confluent.ksql.services.ServiceContext;
import io.confluent.ksql.statement.ConfiguredStatement;
import io.confluent.ksql.statement.Injector;
Expand All @@ -40,15 +42,20 @@ public class DistributingExecutor implements StatementExecutor<Statement> {
private final CommandQueue commandQueue;
private final Duration distributedCmdResponseTimeout;
private final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory;
private final KsqlAuthorizationValidator authorizationValidator;

public DistributingExecutor(
final CommandQueue commandQueue,
final Duration distributedCmdResponseTimeout,
final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory) {
final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory,
final KsqlAuthorizationValidator authorizationValidator
) {
this.commandQueue = Objects.requireNonNull(commandQueue, "commandQueue");
this.distributedCmdResponseTimeout =
Objects.requireNonNull(distributedCmdResponseTimeout, "distributedCmdResponseTimeout");
this.injectorFactory = Objects.requireNonNull(injectorFactory, "injectorFactory");
this.authorizationValidator =
Objects.requireNonNull(authorizationValidator, "authorizationValidator");
}

@Override
Expand All @@ -61,6 +68,8 @@ public Optional<KsqlEntity> execute(
.apply(executionContext, serviceContext)
.inject(statement);

checkAuthorization(injected, serviceContext, executionContext);

try {
final QueuedCommandStatus queuedCommandStatus = commandQueue.enqueueCommand(injected);
final CommandStatus commandStatus = queuedCommandStatus
Expand All @@ -78,4 +87,27 @@ public Optional<KsqlEntity> execute(
statement.getStatementText()), e);
}
}

private void checkAuthorization(
final ConfiguredStatement<?> configured,
final ServiceContext userServiceContext,
final KsqlExecutionContext serverExecutionContext
) {
final Statement statement = configured.getStatement();
final MetaStore metaStore = serverExecutionContext.getMetaStore();

// Check the User will be permitted to execute this statement
authorizationValidator.checkAuthorization(userServiceContext, metaStore, statement);

try {
// Check the KSQL service principal will be permitted too
authorizationValidator.checkAuthorization(
serverExecutionContext.getServiceContext(),
metaStore,
statement
);
} catch (final Exception e) {
throw new KsqlServerException("The KSQL server is not permitted to execute the command", e);
}
spena marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,16 @@ public KsqlResource(
CustomValidators.VALIDATOR_MAP,
injectorFactory,
ksqlEngine::createSandbox,
ksqlConfig,
authorizationValidator);
ksqlConfig
);
this.handler = new RequestHandler(
CustomExecutors.EXECUTOR_MAP,
new DistributingExecutor(
commandQueue,
distributedCmdResponseTimeout,
injectorFactory),
injectorFactory,
authorizationValidator
),
ksqlEngine,
ksqlConfig,
new DefaultCommandQueueSync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import io.confluent.ksql.parser.tree.RunScript;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.rest.util.QueryCapacityUtil;
import io.confluent.ksql.security.KsqlAuthorizationValidator;
import io.confluent.ksql.services.ServiceContext;
import io.confluent.ksql.statement.ConfiguredStatement;
import io.confluent.ksql.statement.Injector;
Expand Down Expand Up @@ -55,7 +54,6 @@ public class RequestValidator {
private final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory;
private final Function<ServiceContext, KsqlExecutionContext> snapshotSupplier;
private final KsqlConfig ksqlConfig;
private final KsqlAuthorizationValidator authorizationValidator;

/**
* @param customValidators a map describing how to validate each statement of type
Expand All @@ -69,14 +67,12 @@ public RequestValidator(
final Map<Class<? extends Statement>, StatementValidator<?>> customValidators,
final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory,
final Function<ServiceContext, KsqlExecutionContext> snapshotSupplier,
final KsqlConfig ksqlConfig,
final KsqlAuthorizationValidator authorizationValidator
final KsqlConfig ksqlConfig
) {
this.customValidators = requireNonNull(customValidators, "customValidators");
this.injectorFactory = requireNonNull(injectorFactory, "injectorFactory");
this.snapshotSupplier = requireNonNull(snapshotSupplier, "snapshotSupplier");
this.ksqlConfig = requireNonNull(ksqlConfig, "ksqlConfig");
this.authorizationValidator = authorizationValidator;
}

/**
Expand Down Expand Up @@ -143,13 +139,6 @@ private <T extends Statement> int validate(
customValidator.validate(configured, executionContext, serviceContext);
} else if (KsqlEngine.isExecutableStatement(configured.getStatement())) {
final ConfiguredStatement<?> statementInjected = injector.inject(configured);

authorizationValidator.checkAuthorization(
serviceContext,
executionContext.getMetaStore(),
statementInjected.getStatement()
);

executionContext.execute(serviceContext, statementInjected);
} else {
throw new KsqlStatementException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.KsqlExecutionContext;
import io.confluent.ksql.exception.KsqlTopicAuthorizationException;
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.parser.KsqlParser.PreparedStatement;
import io.confluent.ksql.parser.tree.ListProperties;
import io.confluent.ksql.parser.tree.Statement;
Expand All @@ -32,6 +38,7 @@
import io.confluent.ksql.rest.entity.CommandStatusEntity;
import io.confluent.ksql.rest.server.computation.CommandId.Action;
import io.confluent.ksql.rest.server.computation.CommandId.Type;
import io.confluent.ksql.security.KsqlAuthorizationValidator;
import io.confluent.ksql.services.ServiceContext;
import io.confluent.ksql.statement.ConfiguredStatement;
import io.confluent.ksql.statement.Injector;
Expand Down Expand Up @@ -73,6 +80,9 @@ public class DistributingExecutorTest {
@Mock ServiceContext serviceContext;
@Mock Injector schemaInjector;
@Mock Injector topicInjector;
@Mock KsqlAuthorizationValidator authorizationValidator;
@Mock KsqlExecutionContext executionContext;
@Mock MetaStore metaStore;

private DistributingExecutor distributor;
private AtomicLong scnCounter;
Expand All @@ -86,17 +96,21 @@ public void setUp() throws InterruptedException {
when(status.tryWaitForFinalStatus(any())).thenReturn(SUCCESS_STATUS);
when(status.getCommandId()).thenReturn(CS_COMMAND);
when(status.getCommandSequenceNumber()).thenAnswer(inv -> scnCounter.incrementAndGet());
when(executionContext.getMetaStore()).thenReturn(metaStore);
when(executionContext.getServiceContext()).thenReturn(serviceContext);

distributor = new DistributingExecutor(
queue,
DURATION_10_MS,
(ec, sc) -> InjectorChain.of(schemaInjector, topicInjector));
(ec, sc) -> InjectorChain.of(schemaInjector, topicInjector),
authorizationValidator
);
}

@Test
public void shouldEnqueueSuccessfulCommand() throws InterruptedException {
// When:
distributor.execute(EMPTY_STATEMENT, null, serviceContext);
distributor.execute(EMPTY_STATEMENT, executionContext, serviceContext);

// Then:
verify(queue, times(1)).enqueueCommand(eq(EMPTY_STATEMENT));
Expand All @@ -105,7 +119,7 @@ public void shouldEnqueueSuccessfulCommand() throws InterruptedException {
@Test
public void shouldInferSchemas() {
// When:
distributor.execute(EMPTY_STATEMENT, null, serviceContext);
distributor.execute(EMPTY_STATEMENT, executionContext, serviceContext);

// Then:
verify(schemaInjector, times(1)).inject(eq(EMPTY_STATEMENT));
Expand All @@ -117,7 +131,7 @@ public void shouldReturnCommandStatus() {
final CommandStatusEntity commandStatusEntity =
(CommandStatusEntity) distributor.execute(
EMPTY_STATEMENT,
null,
executionContext,
serviceContext
)
.orElseThrow(null);
Expand Down Expand Up @@ -150,7 +164,7 @@ public void shouldThrowExceptionOnFailureToEnqueue() {
expectedException.expectCause(is(cause));

// When:
distributor.execute(configured, null, serviceContext);
distributor.execute(configured, executionContext, serviceContext);
}

@Test
Expand All @@ -167,7 +181,43 @@ public void shouldThrowFailureIfCannotInferSchema() {
expectedException.expectMessage("Could not infer!");

// When:
distributor.execute(configured, null, serviceContext);
distributor.execute(configured, executionContext, serviceContext);
}

@Test
public void shouldThrowExceptionIfUserServiceContextIsDeniedAuthorization() {
// Given:
final ServiceContext userServiceContext = mock(ServiceContext.class);
final PreparedStatement<Statement> preparedStatement =
PreparedStatement.of("", new ListProperties(Optional.empty()));
final ConfiguredStatement<Statement> configured =
ConfiguredStatement.of(preparedStatement, ImmutableMap.of(), KSQL_CONFIG);
doThrow(KsqlTopicAuthorizationException.class).when(authorizationValidator)
.checkAuthorization(eq(userServiceContext), any(), eq(configured.getStatement()));

// Expect:
expectedException.expect(KsqlTopicAuthorizationException.class);

// When:
distributor.execute(configured, executionContext, userServiceContext);
}

@Test
public void shouldThrowServerExceptionIfServerServiceContextIsDeniedAuthorization() {
// Given:
final ServiceContext userServiceContext = mock(ServiceContext.class);
final PreparedStatement<Statement> preparedStatement =
PreparedStatement.of("", new ListProperties(Optional.empty()));
final ConfiguredStatement<Statement> configured =
ConfiguredStatement.of(preparedStatement, ImmutableMap.of(), KSQL_CONFIG);
doThrow(KsqlTopicAuthorizationException.class).when(authorizationValidator)
.checkAuthorization(eq(serviceContext), any(), eq(configured.getStatement()));

// Expect:
expectedException.expect(KsqlServerException.class);
expectedException.expectCause(is(instanceOf(KsqlTopicAuthorizationException.class)));

// When:
distributor.execute(configured, executionContext, userServiceContext);
}
}
Loading