diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopic.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopic.java index d554bd585b1e..32348a60fc4a 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopic.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopic.java @@ -16,9 +16,6 @@ package io.confluent.ksql.rest.server; import com.google.common.collect.Lists; -import io.confluent.ksql.rest.entity.CommandId; -import io.confluent.ksql.rest.server.computation.Command; -import io.confluent.ksql.rest.server.computation.InternalTopicSerdes; import io.confluent.ksql.rest.server.computation.QueuedCommand; import java.time.Duration; import java.util.Collections; @@ -31,6 +28,7 @@ import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.clients.consumer.KafkaConsumer; import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -39,7 +37,7 @@ public class CommandTopic { private static final Logger log = LoggerFactory.getLogger(CommandTopic.class); private final TopicPartition commandTopicPartition; - private Consumer commandConsumer = null; + private Consumer commandConsumer; private final String commandTopicName; private CommandTopicBackup commandTopicBackup; @@ -52,8 +50,8 @@ public CommandTopic( commandTopicName, new KafkaConsumer<>( Objects.requireNonNull(kafkaConsumerProperties, "kafkaClientProperties"), - InternalTopicSerdes.deserializer(CommandId.class), - InternalTopicSerdes.deserializer(Command.class) + new ByteArrayDeserializer(), + new ByteArrayDeserializer() ), commandTopicBackup ); @@ -61,7 +59,7 @@ public CommandTopic( CommandTopic( final String commandTopicName, - final Consumer commandConsumer, + final Consumer commandConsumer, final CommandTopicBackup commandTopicBackup ) { this.commandTopicPartition = new TopicPartition(commandTopicName, 0); @@ -79,11 +77,11 @@ public void start() { commandConsumer.assign(Collections.singleton(commandTopicPartition)); } - public Iterable> getNewCommands(final Duration timeout) { - final Iterable> iterable = commandConsumer.poll(timeout); + public Iterable> getNewCommands(final Duration timeout) { + final Iterable> iterable = commandConsumer.poll(timeout); if (iterable != null) { - iterable.forEach(record -> backupRecord(record)); + iterable.forEach(this::backupRecord); } return iterable; @@ -96,11 +94,11 @@ public List getRestoreCommands(final Duration duration) { Collections.singletonList(commandTopicPartition)); log.debug("Reading prior command records"); - ConsumerRecords records = + ConsumerRecords records = commandConsumer.poll(duration); while (!records.isEmpty()) { log.debug("Received {} records from poll", records.count()); - for (final ConsumerRecord record : records) { + for (final ConsumerRecord record : records) { backupRecord(record); if (record.value() == null) { @@ -136,7 +134,7 @@ public void close() { commandTopicBackup.close(); } - private void backupRecord(final ConsumerRecord record) { + private void backupRecord(final ConsumerRecord record) { commandTopicBackup.writeRecord(record); } } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopicBackup.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopicBackup.java index 1c5aa811a884..d12b23bb5d1b 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopicBackup.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopicBackup.java @@ -15,14 +15,12 @@ package io.confluent.ksql.rest.server; -import io.confluent.ksql.rest.entity.CommandId; -import io.confluent.ksql.rest.server.computation.Command; import org.apache.kafka.clients.consumer.ConsumerRecord; public interface CommandTopicBackup { void initialize(); - void writeRecord(ConsumerRecord record); + void writeRecord(ConsumerRecord record); void close(); } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopicBackupImpl.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopicBackupImpl.java index d2a841a9943c..b0fe42838c6c 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopicBackupImpl.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopicBackupImpl.java @@ -19,6 +19,7 @@ import com.google.common.base.Ticker; import io.confluent.ksql.rest.entity.CommandId; import io.confluent.ksql.rest.server.computation.Command; +import io.confluent.ksql.rest.server.computation.InternalTopicSerdes; import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.KsqlServerException; import io.confluent.ksql.util.Pair; @@ -125,7 +126,27 @@ private boolean isRecordInLatestReplay(final ConsumerRecord } @Override - public void writeRecord(final ConsumerRecord record) { + public void writeRecord(final ConsumerRecord record) { + final ConsumerRecord deserializedRecord; + try { + deserializedRecord = new ConsumerRecord<>( + record.topic(), + record.partition(), + record.offset(), + InternalTopicSerdes.deserializer(CommandId.class) + .deserialize(record.topic(), record.key()), + InternalTopicSerdes.deserializer(Command.class) + .deserialize(record.topic(), record.value()) + ); + } catch (Exception e) { + LOG.error("Failed to deserialize command topic record when backing it up: {}:{}", + record.key(), record.value()); + return; + } + writeCommandToBackup(deserializedRecord); + } + + void writeCommandToBackup(final ConsumerRecord record) { if (isRestoring()) { if (isRecordInLatestReplay(record)) { // Ignore backup because record was already replayed diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopicBackupNoOp.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopicBackupNoOp.java index 5fb860def9b8..efaafdc46629 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopicBackupNoOp.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/CommandTopicBackupNoOp.java @@ -15,8 +15,6 @@ package io.confluent.ksql.rest.server; -import io.confluent.ksql.rest.entity.CommandId; -import io.confluent.ksql.rest.server.computation.Command; import org.apache.kafka.clients.consumer.ConsumerRecord; public class CommandTopicBackupNoOp implements CommandTopicBackup { @@ -26,7 +24,7 @@ public void initialize() { } @Override - public void writeRecord(final ConsumerRecord record) { + public void writeRecord(final ConsumerRecord record) { // no-op } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java index b565bb83e86c..e68f8153bb35 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java @@ -60,9 +60,11 @@ import io.confluent.ksql.rest.entity.SourceInfo; import io.confluent.ksql.rest.entity.StreamsList; import io.confluent.ksql.rest.server.HeartbeatAgent.Builder; +import io.confluent.ksql.rest.server.computation.Command; import io.confluent.ksql.rest.server.computation.CommandRunner; import io.confluent.ksql.rest.server.computation.CommandStore; import io.confluent.ksql.rest.server.computation.InteractiveStatementExecutor; +import io.confluent.ksql.rest.server.computation.InternalTopicSerdes; import io.confluent.ksql.rest.server.execution.PullQueryExecutor; import io.confluent.ksql.rest.server.resources.ClusterStatusResource; import io.confluent.ksql.rest.server.resources.HealthCheckResource; @@ -742,7 +744,8 @@ static KsqlRestApplication buildApplication( ksqlConfig.getString(KsqlConfig.KSQL_SERVICE_ID_CONFIG), Duration.ofMillis(restConfig.getLong( KsqlRestConfig.KSQL_COMMAND_RUNNER_BLOCKED_THRESHHOLD_ERROR_MS)), - metricsPrefix + metricsPrefix, + InternalTopicSerdes.deserializer(Command.class) ); final QueryMonitor queryMonitor = new QueryMonitor(ksqlConfig, ksqlEngine); diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandRunner.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandRunner.java index c0dbae60fcdc..adff957ac062 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandRunner.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandRunner.java @@ -27,6 +27,7 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Objects; @@ -35,8 +36,11 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; +import org.apache.kafka.common.errors.SerializationException; import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.serialization.Deserializer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -70,9 +74,14 @@ public class CommandRunner implements Closeable { private final Duration commandRunnerHealthTimeout; private final Clock clock; + private final Deserializer commandDeserializer; + private final Consumer incompatibleCommandChecker; + private boolean deserializationErrorThrown; + public enum CommandRunnerStatus { RUNNING, - ERROR + ERROR, + DEGRADED } public CommandRunner( @@ -83,7 +92,8 @@ public CommandRunner( final ServerState serverState, final String ksqlServiceId, final Duration commandRunnerHealthTimeout, - final String metricsGroupPrefix + final String metricsGroupPrefix, + final Deserializer commandDeserializer ) { this( statementExecutor, @@ -96,7 +106,12 @@ public CommandRunner( commandRunnerHealthTimeout, metricsGroupPrefix, Clock.systemUTC(), - RestoreCommandsCompactor::compact + RestoreCommandsCompactor::compact, + queuedCommand -> { + queuedCommand.getAndDeserializeCommandId(); + queuedCommand.getAndDeserializeCommand(commandDeserializer); + }, + commandDeserializer ); } @@ -113,7 +128,9 @@ public CommandRunner( final Duration commandRunnerHealthTimeout, final String metricsGroupPrefix, final Clock clock, - final Function, List> compactor + final Function, List> compactor, + final Consumer incompatibleCommandChecker, + final Deserializer commandDeserializer ) { // CHECKSTYLE_RULES.ON: ParameterNumberCheck this.statementExecutor = Objects.requireNonNull(statementExecutor, "statementExecutor"); @@ -130,6 +147,11 @@ public CommandRunner( new CommandRunnerStatusMetric(ksqlServiceId, this, metricsGroupPrefix); this.clock = Objects.requireNonNull(clock, "clock"); this.compactor = Objects.requireNonNull(compactor, "compactor"); + this.incompatibleCommandChecker = + Objects.requireNonNull(incompatibleCommandChecker, "incompatibleCommandChecker"); + this.commandDeserializer = + Objects.requireNonNull(commandDeserializer, "commandDeserializer"); + this.deserializationErrorThrown = false; } /** @@ -146,6 +168,16 @@ public void start() { */ @Override public void close() { + if (!closed) { + closeEarly(); + } + commandRunnerStatusMetric.close(); + } + + /** + * Closes the poll-execute loop before the server shuts down + */ + private void closeEarly() { try { closed = true; commandStore.wakeup(); @@ -153,7 +185,6 @@ public void close() { } catch (final InterruptedException e) { Thread.currentThread().interrupt(); } - commandRunnerStatusMetric.close(); } /** @@ -162,17 +193,19 @@ public void close() { public void processPriorCommands() { try { final List restoreCommands = commandStore.getRestoreCommands(); + final List compatibleCommands = checkForIncompatibleCommands(restoreCommands); - LOG.info("Restoring previous state from {} commands.", restoreCommands.size()); + LOG.info("Restoring previous state from {} commands.", compatibleCommands.size()); - final Optional terminateCmd = findTerminateCommand(restoreCommands); + final Optional terminateCmd = + findTerminateCommand(compatibleCommands, commandDeserializer); if (terminateCmd.isPresent()) { LOG.info("Cluster previously terminated: terminating."); - terminateCluster(terminateCmd.get().getCommand()); + terminateCluster(terminateCmd.get().getAndDeserializeCommand(commandDeserializer)); return; } - final List compacted = compactor.apply(restoreCommands); + final List compacted = compactor.apply(compatibleCommands); compacted.forEach( command -> { @@ -211,14 +244,16 @@ void fetchAndRunCommands() { return; } - final Optional terminateCmd = findTerminateCommand(commands); + final List compatibleCommands = checkForIncompatibleCommands(commands); + final Optional terminateCmd = + findTerminateCommand(compatibleCommands, commandDeserializer); if (terminateCmd.isPresent()) { - terminateCluster(terminateCmd.get().getCommand()); + terminateCluster(terminateCmd.get().getAndDeserializeCommand(commandDeserializer)); return; } - LOG.debug("Found {} new writes to command topic", commands.size()); - for (final QueuedCommand command : commands) { + LOG.debug("Found {} new writes to command topic", compatibleCommands.size()); + for (final QueuedCommand command : compatibleCommands) { if (closed) { return; } @@ -228,14 +263,16 @@ void fetchAndRunCommands() { } private void executeStatement(final QueuedCommand queuedCommand) { - LOG.info("Executing statement: " + queuedCommand.getCommand().getStatement()); + final String commandStatement = + queuedCommand.getAndDeserializeCommand(commandDeserializer).getStatement(); + LOG.info("Executing statement: " + commandStatement); final Runnable task = () -> { if (closed) { LOG.info("Execution aborted as system is closing down"); } else { statementExecutor.handleStatement(queuedCommand); - LOG.info("Executed statement: " + queuedCommand.getCommand().getStatement()); + LOG.info("Executed statement: " + commandStatement); } }; @@ -251,10 +288,11 @@ private void executeStatement(final QueuedCommand queuedCommand) { } private static Optional findTerminateCommand( - final List restoreCommands + final List restoreCommands, + final Deserializer commandDeserializer ) { return restoreCommands.stream() - .filter(command -> command.getCommand().getStatement() + .filter(command -> command.getAndDeserializeCommand(commandDeserializer).getStatement() .equalsIgnoreCase(TerminateCluster.TERMINATE_CLUSTER_STATEMENT_TEXT)) .findFirst(); } @@ -272,6 +310,10 @@ private void terminateCluster(final Command command) { } CommandRunnerStatus checkCommandRunnerStatus() { + if (deserializationErrorThrown) { + return CommandRunnerStatus.DEGRADED; + } + final Pair currentCommand = currentCommandRef.get(); if (currentCommand == null) { return lastPollTime.get() == null @@ -285,14 +327,33 @@ CommandRunnerStatus checkCommandRunnerStatus() { ? CommandRunnerStatus.RUNNING : CommandRunnerStatus.ERROR; } + private List checkForIncompatibleCommands(final List commands) { + final List compatibleCommands = new ArrayList<>(); + try { + for (final QueuedCommand command : commands) { + incompatibleCommandChecker.accept(command); + compatibleCommands.add(command); + } + } catch (SerializationException e) { + LOG.info("Deserialization error detected when processing record", e); + deserializationErrorThrown = true; + } + return compatibleCommands; + } + private class Runner implements Runnable { @Override public void run() { try { while (!closed) { - LOG.trace("Polling for new writes to command topic"); - fetchAndRunCommands(); + if (deserializationErrorThrown) { + LOG.warn("CommandRunner entering degraded state after failing to deserialize command"); + closeEarly(); + } else { + LOG.trace("Polling for new writes to command topic"); + fetchAndRunCommands(); + } } } catch (final WakeupException wue) { if (!closed) { diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandRunnerStatusMetric.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandRunnerStatusMetric.java index b9a5f84b1fe8..4f605300cc19 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandRunnerStatusMetric.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandRunnerStatusMetric.java @@ -35,7 +35,6 @@ public class CommandRunnerStatusMetric implements Closeable { private final Metrics metrics; private final MetricName metricName; - private final String metricGroupName; CommandRunnerStatusMetric( final String ksqlServiceId, @@ -58,7 +57,7 @@ public class CommandRunnerStatusMetric implements Closeable { final String metricsGroupPrefix ) { this.metrics = Objects.requireNonNull(metrics, "metrics"); - this.metricGroupName = metricsGroupPrefix + METRIC_GROUP_POST_FIX; + final String metricGroupName = metricsGroupPrefix + METRIC_GROUP_POST_FIX; this.metricName = metrics.metricName( "status", ReservedInternalTopics.KSQL_INTERNAL_TOPIC_PREFIX + ksqlServiceId + metricGroupName, diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandStore.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandStore.java index 9d4a697f9fc1..6d24f6ccfe0f 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandStore.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandStore.java @@ -37,8 +37,10 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; + import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.KafkaConsumer; import org.apache.kafka.clients.producer.KafkaProducer; import org.apache.kafka.clients.producer.Producer; @@ -47,6 +49,11 @@ import org.apache.kafka.clients.producer.RecordMetadata; import org.apache.kafka.common.IsolationLevel; import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serializer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Wrapper class for the command topic. Used for reading from the topic (either all messages from @@ -55,6 +62,7 @@ // CHECKSTYLE_RULES.OFF: ClassDataAbstractionCoupling public class CommandStore implements CommandQueue, Closeable { + private static final Logger LOG = LoggerFactory.getLogger(CommandStore.class); private static final Duration POLLING_TIMEOUT_FOR_COMMAND_TOPIC = Duration.ofMillis(5000); private static final int COMMAND_TOPIC_PARTITION = 0; @@ -66,6 +74,10 @@ public class CommandStore implements CommandQueue, Closeable { private final Duration commandQueueCatchupTimeout; private final Map kafkaConsumerProperties; private final Map kafkaProducerProperties; + private final Serializer commandIdSerializer; + private final Serializer commandSerializer; + private final Deserializer commandIdDeserializer; + public static final class Factory { @@ -116,7 +128,10 @@ public static CommandStore create( new SequenceNumberFutureStore(), kafkaConsumerProperties, kafkaProducerProperties, - commandQueueCatchupTimeout + commandQueueCatchupTimeout, + InternalTopicSerdes.serializer(), + InternalTopicSerdes.serializer(), + InternalTopicSerdes.deserializer(CommandId.class) ); } } @@ -127,7 +142,10 @@ public static CommandStore create( final SequenceNumberFutureStore sequenceNumberFutureStore, final Map kafkaConsumerProperties, final Map kafkaProducerProperties, - final Duration commandQueueCatchupTimeout + final Duration commandQueueCatchupTimeout, + final Serializer commandIdSerializer, + final Serializer commandSerializer, + final Deserializer commandIdDeserializer ) { this.commandTopic = Objects.requireNonNull(commandTopic, "commandTopic"); this.commandStatusMap = Maps.newConcurrentMap(); @@ -140,6 +158,12 @@ public static CommandStore create( this.kafkaProducerProperties = Objects.requireNonNull(kafkaProducerProperties, "kafkaProducerProperties"); this.commandTopicName = Objects.requireNonNull(commandTopicName, "commandTopicName"); + this.commandIdSerializer = + Objects.requireNonNull(commandIdSerializer, "commandIdSerializer"); + this.commandSerializer = + Objects.requireNonNull(commandSerializer, "commandSerializer"); + this.commandIdDeserializer = + Objects.requireNonNull(commandIdDeserializer, "commandIdDeserializer"); } @Override @@ -210,22 +234,31 @@ public QueuedCommandStatus enqueueCommand( public List getNewCommands(final Duration timeout) { completeSatisfiedSequenceNumberFutures(); - final List queuedCommands = Lists.newArrayList(); - commandTopic.getNewCommands(timeout).forEach( - c -> { - if (c.value() != null) { - queuedCommands.add( - new QueuedCommand( - c.key(), - c.value(), - Optional.ofNullable(commandStatusMap.remove(c.key())), - c.offset() - ) - ); - } + final List commands = Lists.newArrayList(); + + final Iterable> records = commandTopic.getNewCommands(timeout); + for (ConsumerRecord record: records) { + if (record.value() != null) { + Optional commandStatusFuture = Optional.empty(); + try { + final CommandId commandId = + commandIdDeserializer.deserialize(commandTopicName, record.key()); + commandStatusFuture = Optional.ofNullable(commandStatusMap.remove(commandId)); + } catch (Exception e) { + LOG.warn( + "Error while attempting to fetch from commandStatusMap for key {}", + record.key(), + e); } - ); - return queuedCommands; + commands.add(new QueuedCommand( + record.key(), + record.value(), + commandStatusFuture, + record.offset())); + } + } + + return commands; } @Override @@ -263,8 +296,8 @@ public void ensureConsumedPast(final long seqNum, final Duration timeout) public Producer createTransactionalProducer() { return new KafkaProducer<>( kafkaProducerProperties, - InternalTopicSerdes.serializer(), - InternalTopicSerdes.serializer() + commandIdSerializer, + commandSerializer ); } @@ -291,10 +324,10 @@ private long getCommandTopicOffset() { COMMAND_TOPIC_PARTITION ); - try (Consumer commandConsumer = new KafkaConsumer<>( + try (Consumer commandConsumer = new KafkaConsumer<>( kafkaConsumerProperties, - InternalTopicSerdes.deserializer(CommandId.class), - InternalTopicSerdes.deserializer(Command.class) + new ByteArrayDeserializer(), + new ByteArrayDeserializer() )) { commandConsumer.assign(Collections.singleton(commandTopicPartition)); return commandConsumer.endOffsets(Collections.singletonList(commandTopicPartition)) diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/InteractiveStatementExecutor.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/InteractiveStatementExecutor.java index 8a670f3ca99c..6d1e5ff56fdb 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/InteractiveStatementExecutor.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/InteractiveStatementExecutor.java @@ -44,6 +44,8 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; + +import org.apache.kafka.common.serialization.Deserializer; import org.apache.kafka.streams.StreamsConfig; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -61,6 +63,7 @@ public class InteractiveStatementExecutor implements KsqlConfigurable { private final StatementParser statementParser; private final SpecificQueryIdGenerator queryIdGenerator; private final Map statusStore; + private final Deserializer commandDeserializer; private KsqlConfig ksqlConfig; private enum Mode { @@ -77,7 +80,8 @@ public InteractiveStatementExecutor( serviceContext, ksqlEngine, new StatementParser(ksqlEngine), - queryIdGenerator + queryIdGenerator, + InternalTopicSerdes.deserializer(Command.class) ); } @@ -86,12 +90,14 @@ public InteractiveStatementExecutor( final ServiceContext serviceContext, final KsqlEngine ksqlEngine, final StatementParser statementParser, - final SpecificQueryIdGenerator queryIdGenerator + final SpecificQueryIdGenerator queryIdGenerator, + final Deserializer commandDeserializer ) { this.serviceContext = Objects.requireNonNull(serviceContext, "serviceContext"); this.ksqlEngine = Objects.requireNonNull(ksqlEngine, "ksqlEngine"); this.statementParser = Objects.requireNonNull(statementParser, "statementParser"); this.queryIdGenerator = Objects.requireNonNull(queryIdGenerator, "queryIdGenerator"); + this.commandDeserializer = Objects.requireNonNull(commandDeserializer, "commandDeserializer"); this.statusStore = new ConcurrentHashMap<>(); } @@ -117,8 +123,8 @@ void handleStatement(final QueuedCommand queuedCommand) { throwIfNotConfigured(); handleStatementWithTerminatedQueries( - queuedCommand.getCommand(), - queuedCommand.getCommandId(), + queuedCommand.getAndDeserializeCommand(commandDeserializer), + queuedCommand.getAndDeserializeCommandId(), queuedCommand.getStatus(), Mode.EXECUTE, queuedCommand.getOffset() @@ -129,8 +135,8 @@ void handleRestore(final QueuedCommand queuedCommand) { throwIfNotConfigured(); handleStatementWithTerminatedQueries( - queuedCommand.getCommand(), - queuedCommand.getCommandId(), + queuedCommand.getAndDeserializeCommand(commandDeserializer), + queuedCommand.getAndDeserializeCommandId(), queuedCommand.getStatus(), Mode.RESTORE, queuedCommand.getOffset() diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/QueuedCommand.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/QueuedCommand.java index 695c4773313e..a43541fa128b 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/QueuedCommand.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/QueuedCommand.java @@ -15,21 +15,39 @@ package io.confluent.ksql.rest.server.computation; +import com.google.common.annotations.VisibleForTesting; import io.confluent.ksql.rest.entity.CommandId; +import java.util.Arrays; import java.util.Objects; import java.util.Optional; +import org.apache.kafka.common.serialization.Deserializer; public class QueuedCommand { - private final CommandId commandId; - private final Command command; + private final byte[] commandId; + private final byte[] command; private final Optional status; private final Long offset; + @VisibleForTesting public QueuedCommand( final CommandId commandId, final Command command, final Optional status, final Long offset + ) { + this( + InternalTopicSerdes.serializer().serialize("", commandId), + InternalTopicSerdes.serializer().serialize("", command), + status, + offset + ); + } + + public QueuedCommand( + final byte[] commandId, + final byte[] command, + final Optional status, + final Long offset ) { this.commandId = Objects.requireNonNull(commandId, "commandId"); this.command = Objects.requireNonNull(command,"command"); @@ -37,16 +55,26 @@ public QueuedCommand( this.offset = Objects.requireNonNull(offset, "offset"); } - public CommandId getCommandId() { - return commandId; + @VisibleForTesting + byte[] getCommandId() { + return Arrays.copyOf(commandId, commandId.length); } - public Optional getStatus() { - return status; + @VisibleForTesting + byte[] getCommand() { + return Arrays.copyOf(command, command.length); + } + + CommandId getAndDeserializeCommandId() { + return InternalTopicSerdes.deserializer(CommandId.class).deserialize("", commandId); } - public Command getCommand() { - return command; + Command getAndDeserializeCommand(final Deserializer deserializer) { + return deserializer.deserialize("", command); + } + + public Optional getStatus() { + return status; } public Long getOffset() { @@ -62,8 +90,8 @@ public boolean equals(final Object o) { return false; } final QueuedCommand that = (QueuedCommand) o; - return Objects.equals(commandId, that.commandId) - && Objects.equals(command, that.command) + return Arrays.equals(commandId, that.commandId) + && Arrays.equals(command, that.command) && Objects.equals(status, that.status) && Objects.equals(offset, that.offset); } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/RestoreCommandsCompactor.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/RestoreCommandsCompactor.java index bb930876e795..8a5f3e79ab8a 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/RestoreCommandsCompactor.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/RestoreCommandsCompactor.java @@ -71,16 +71,19 @@ private static Set findTerminatedQueriesAndRemoveTerminateCommand final QueuedCommand cmd = it.next(); // Find known queries: - if (cmd.getCommand().getPlan().isPresent() - && cmd.getCommand().getPlan().get().getQueryPlan().isPresent() + final Command command = + cmd.getAndDeserializeCommand(InternalTopicSerdes.deserializer(Command.class)); + if (command.getPlan().isPresent() + && command.getPlan().get().getQueryPlan().isPresent() ) { - final QueryId queryId = cmd.getCommand().getPlan().get().getQueryPlan().get().getQueryId(); + final QueryId queryId = + command.getPlan().get().getQueryPlan().get().getQueryId(); queries.putIfAbsent(queryId, cmd); } // Find TERMINATE's that match known queries: - if (cmd.getCommandId().getType() == Type.TERMINATE) { - final QueryId queryId = new QueryId(cmd.getCommandId().getEntity()); + if (cmd.getAndDeserializeCommandId().getType() == Type.TERMINATE) { + final QueryId queryId = new QueryId(cmd.getAndDeserializeCommandId().getEntity()); final QueuedCommand terminated = queries.remove(queryId); if (terminated != null) { terminatedQueries.add(terminated); @@ -112,7 +115,8 @@ private static void removeQueryPlansOfTerminated(final List compa } private static Optional buildNewCmdWithoutQuery(final QueuedCommand cmd) { - final Command command = cmd.getCommand(); + final Command command = + cmd.getAndDeserializeCommand(InternalTopicSerdes.deserializer(Command.class)); if (!command.getPlan().isPresent() || !command.getPlan().get().getDdlCommand().isPresent()) { // No DDL command, so no command at all if we remove the query plan. (Likely INSERT INTO cmd). return Optional.empty(); @@ -127,8 +131,8 @@ private static Optional buildNewCmdWithoutQuery(final QueuedComma ); return Optional.of(new QueuedCommand( - cmd.getCommandId(), - newCommand, + InternalTopicSerdes.serializer().serialize("", cmd.getAndDeserializeCommandId()), + InternalTopicSerdes.serializer().serialize("", newCommand), cmd.getStatus(), cmd.getOffset() )); diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/CommandTopicBackupImplTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/CommandTopicBackupImplTest.java index 173b34f5e916..e88d9f49f412 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/CommandTopicBackupImplTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/CommandTopicBackupImplTest.java @@ -148,13 +148,13 @@ public void shouldCreateBackupLocationWhenDoesNotExist() throws IOException { } @Test - public void shouldWriteRecordsToReplayFile() throws IOException { + public void shouldWriteCommandToBackupToReplayFile() throws IOException { // Given commandTopicBackup.initialize(); // When final ConsumerRecord record = newConsumerRecord(command1); - commandTopicBackup.writeRecord(record); + commandTopicBackup.writeCommandToBackup(record); // Then final List> commands = @@ -169,13 +169,13 @@ public void shouldIgnoreRecordPreviouslyReplayed() throws IOException { // Given final ConsumerRecord record = newConsumerRecord(command1); commandTopicBackup.initialize(); - commandTopicBackup.writeRecord(record); + commandTopicBackup.writeCommandToBackup(record); final BackupReplayFile previousReplayFile = commandTopicBackup.getReplayFile(); // When // A 2nd initialize call will open the latest backup and read the previous replayed commands commandTopicBackup.initialize(); - commandTopicBackup.writeRecord(record); + commandTopicBackup.writeCommandToBackup(record); final BackupReplayFile currentReplayFile = commandTopicBackup.getReplayFile(); // Then @@ -191,7 +191,7 @@ public void shouldCreateNewReplayFileIfNewRecordsDoNotMatchPreviousBackups() thr // Given final ConsumerRecord record1 = newConsumerRecord(command1); commandTopicBackup.initialize(); - commandTopicBackup.writeRecord(record1); + commandTopicBackup.writeCommandToBackup(record1); final BackupReplayFile previousReplayFile = commandTopicBackup.getReplayFile(); // When @@ -201,7 +201,7 @@ public void shouldCreateNewReplayFileIfNewRecordsDoNotMatchPreviousBackups() thr // Need to increase the ticker so the new file has a new timestamp when(ticker.read()).thenReturn(2L); // The write command will create a new replay file with the new command - commandTopicBackup.writeRecord(record2); + commandTopicBackup.writeCommandToBackup(record2); final BackupReplayFile currentReplayFile = commandTopicBackup.getReplayFile(); // Then @@ -222,8 +222,8 @@ public void shouldWritePreviousReplayedRecordsAlreadyChecked() throws IOExceptio final ConsumerRecord record1 = newConsumerRecord(command1); final ConsumerRecord record2 = newConsumerRecord(command2); commandTopicBackup.initialize(); - commandTopicBackup.writeRecord(record1); - commandTopicBackup.writeRecord(record2); + commandTopicBackup.writeCommandToBackup(record1); + commandTopicBackup.writeCommandToBackup(record2); final BackupReplayFile previousReplayFile = commandTopicBackup.getReplayFile(); // When @@ -232,11 +232,11 @@ public void shouldWritePreviousReplayedRecordsAlreadyChecked() throws IOExceptio // Need to increase the ticker so the new file has a new timestamp when(ticker.read()).thenReturn(2L); // command1 is ignored because it was previously replayed - commandTopicBackup.writeRecord(record1); + commandTopicBackup.writeCommandToBackup(record1); // The write command will create a new replay file with the new command, and command1 will // be written to have a complete backup final ConsumerRecord record3 = newConsumerRecord(command3); - commandTopicBackup.writeRecord(record3); + commandTopicBackup.writeCommandToBackup(record3); final BackupReplayFile currentReplayFile = commandTopicBackup.getReplayFile(); // Then diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/CommandTopicTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/CommandTopicTest.java index df07739a6871..a88cf2a38c3d 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/CommandTopicTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/CommandTopicTest.java @@ -26,19 +26,21 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.confluent.ksql.rest.entity.CommandId; -import io.confluent.ksql.rest.server.computation.Command; -import io.confluent.ksql.rest.server.computation.QueuedCommand; + +import java.nio.charset.Charset; import java.time.Duration; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Optional; + +import io.confluent.ksql.rest.server.computation.QueuedCommand; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.common.TopicPartition; +import org.hamcrest.Matchers; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -54,27 +56,27 @@ public class CommandTopicTest { private static final String COMMAND_TOPIC_NAME = "foo"; @Mock - private Consumer commandConsumer; + private Consumer commandConsumer; private CommandTopic commandTopic; @Mock private CommandTopicBackup commandTopicBackup; - @Mock - private CommandId commandId1; - @Mock - private Command command1; - @Mock - private CommandId commandId2; - @Mock - private Command command2; - @Mock - private CommandId commandId3; - @Mock - private Command command3; + + private final byte[] commandId1 = "commandId1".getBytes(Charset.defaultCharset()); + private final byte[] command1 = "command1".getBytes(Charset.defaultCharset()); + private final byte[] commandId2 = "commandId2".getBytes(Charset.defaultCharset()); + private final byte[] command2 = "command2".getBytes(Charset.defaultCharset()); + private final byte[] commandId3 = "commandId3".getBytes(Charset.defaultCharset()); + private final byte[] command3 = "command3".getBytes(Charset.defaultCharset()); + + private ConsumerRecord record1; + private ConsumerRecord record2; + private ConsumerRecord record3; + @Mock - private ConsumerRecords consumerRecords; + private ConsumerRecords consumerRecords; @Captor private ArgumentCaptor> topicPartitionsCaptor; @@ -83,6 +85,9 @@ public class CommandTopicTest { @Before @SuppressWarnings("unchecked") public void setup() { + record1 = new ConsumerRecord<>("topic", 0, 0, commandId1, command1); + record2 = new ConsumerRecord<>("topic", 0, 1, commandId2, command2); + record3 = new ConsumerRecord<>("topic", 0, 2, commandId3, command3); commandTopic = new CommandTopic(COMMAND_TOPIC_NAME, commandConsumer, commandTopicBackup); } @@ -102,7 +107,7 @@ public void shouldGetNewCommandsIteratorCorrectly() { when(commandConsumer.poll(any(Duration.class))).thenReturn(consumerRecords); // When: - final Iterable> newCommands = commandTopic + final Iterable> newCommands = commandTopic .getNewCommands(Duration.ofHours(1)); // Then: @@ -114,10 +119,10 @@ public void shouldGetRestoreCommandsCorrectly() { // Given: when(commandConsumer.poll(any(Duration.class))) .thenReturn(someConsumerRecords( - new ConsumerRecord<>("topic", 0, 0, commandId1, command1), - new ConsumerRecord<>("topic", 0, 1, commandId2, command2))) + record1, + record2)) .thenReturn(someConsumerRecords( - new ConsumerRecord<>("topic", 0, 2, commandId3, command3))) + record3)) .thenReturn(new ConsumerRecords<>(Collections.emptyMap())); // When: @@ -130,7 +135,7 @@ public void shouldGetRestoreCommandsCorrectly() { equalTo(Collections.singletonList(new TopicPartition(COMMAND_TOPIC_NAME, 0)))); assertThat(queuedCommandList, equalTo(ImmutableList.of( new QueuedCommand(commandId1, command1, Optional.empty(), 0L), - new QueuedCommand(commandId2, command2, Optional.empty(), 1L), + new QueuedCommand(commandId2, command2, Optional.empty(),1L), new QueuedCommand(commandId3, command3, Optional.empty(), 2L)))); } @@ -138,25 +143,25 @@ public void shouldGetRestoreCommandsCorrectly() { public void shouldHaveOffsetsInQueuedCommands() { // Given: when(commandConsumer.poll(any(Duration.class))) - .thenReturn(someConsumerRecords( - new ConsumerRecord<>("topic", 0, 0, commandId1, command1), - new ConsumerRecord<>("topic", 0, 1, commandId2, command2))) - .thenReturn(someConsumerRecords( - new ConsumerRecord<>("topic", 0, 2, commandId3, command3))) - .thenReturn(new ConsumerRecords<>(Collections.emptyMap())); + .thenReturn(someConsumerRecords( + new ConsumerRecord<>("topic", 0, 0, commandId1, command1), + new ConsumerRecord<>("topic", 0, 1, commandId2, command2))) + .thenReturn(someConsumerRecords( + new ConsumerRecord<>("topic", 0, 2, commandId3, command3))) + .thenReturn(new ConsumerRecords<>(Collections.emptyMap())); // When: final List queuedCommandList = commandTopic - .getRestoreCommands(Duration.ofMillis(1)); + .getRestoreCommands(Duration.ofMillis(1)); // Then: verify(commandConsumer).seekToBeginning(topicPartitionsCaptor.capture()); assertThat(topicPartitionsCaptor.getValue(), - equalTo(Collections.singletonList(new TopicPartition(COMMAND_TOPIC_NAME, 0)))); + equalTo(Collections.singletonList(new TopicPartition(COMMAND_TOPIC_NAME, 0)))); assertThat(queuedCommandList, equalTo(ImmutableList.of( - new QueuedCommand(commandId1, command1, Optional.empty(), 0L), - new QueuedCommand(commandId2, command2, Optional.empty(),1L), - new QueuedCommand(commandId3, command3, Optional.empty(), 2L)))); + new QueuedCommand(commandId1, command1, Optional.empty(), 0L), + new QueuedCommand(commandId2, command2, Optional.empty(),1L), + new QueuedCommand(commandId3, command3, Optional.empty(), 2L)))); } @Test @@ -183,26 +188,25 @@ public void shouldGetRestoreCommandsCorrectlyWithDuplicateKeys() { new QueuedCommand(commandId3, command3, Optional.empty(), 3L)))); } - @Test public void shouldFilterNullCommandsWhileRestoringCommands() { // Given: when(commandConsumer.poll(any(Duration.class))) .thenReturn(someConsumerRecords( - new ConsumerRecord<>("topic", 0, 0, commandId1, command1), - new ConsumerRecord<>("topic", 0, 1, commandId2, command2), + record1, + record2, new ConsumerRecord<>("topic", 0, 2, commandId2, null) )) .thenReturn(new ConsumerRecords<>(Collections.emptyMap())); // When: - final List queuedCommandList = commandTopic + final List recordList = commandTopic .getRestoreCommands(Duration.ofMillis(1)); // Then: - assertThat(queuedCommandList, equalTo(ImmutableList.of( + assertThat(recordList, equalTo(ImmutableList.of( new QueuedCommand(commandId1, command1, Optional.empty(), 0L), - new QueuedCommand(commandId2, command2, Optional.empty(), 1L)))); + new QueuedCommand(commandId2, command2, Optional.empty(),1L)))); } @Test @@ -226,10 +230,7 @@ public void shouldCloseAllResources() { @Test public void shouldHaveAllCreateCommandsInOrder() { // Given: - final ConsumerRecords records = someConsumerRecords( - new ConsumerRecord<>("topic", 0, 0, commandId1, command1), - new ConsumerRecord<>("topic", 0, 1, commandId2, command2), - new ConsumerRecord<>("topic", 0, 2, commandId3, command3)); + final ConsumerRecords records = someConsumerRecords(record1, record2, record3); when(commandTopic.getNewCommands(any())) .thenReturn(records) @@ -242,7 +243,7 @@ public void shouldHaveAllCreateCommandsInOrder() { assertThat(commands, equalTo(Arrays.asList( new QueuedCommand(commandId1, command1, Optional.empty(), 0L), new QueuedCommand(commandId2, command2, Optional.empty(), 1L), - new QueuedCommand(commandId3, command3, Optional.empty(), 2L) + new QueuedCommand(commandId3, command3, Optional.empty(), 2L) ))); } @@ -267,10 +268,6 @@ public void shouldCloseCommandTopicBackup() { @Test public void shouldBackupRestoreCommands() { // Given - final ConsumerRecord record1 = - new ConsumerRecord<>("topic", 0, 0, commandId1, command1); - final ConsumerRecord record2 = - new ConsumerRecord<>("topic", 0, 0, commandId2, command2); when(commandConsumer.poll(any(Duration.class))) .thenReturn(someConsumerRecords(record1, record2)) .thenReturn(new ConsumerRecords<>(Collections.emptyMap())); @@ -288,10 +285,6 @@ public void shouldBackupRestoreCommands() { @Test public void shouldBackupNewCommands() { // Given - final ConsumerRecord record1 = - new ConsumerRecord<>("topic", 0, 0, commandId1, command1); - final ConsumerRecord record2 = - new ConsumerRecord<>("topic", 0, 1, commandId2, command2); when(commandConsumer.poll(any(Duration.class))) .thenReturn(someConsumerRecords(record1, record2)) .thenReturn(new ConsumerRecords<>(Collections.emptyMap())); @@ -322,8 +315,8 @@ public void shouldGetEndOffsetCorrectly() { @SuppressWarnings("varargs") @SafeVarargs - private static ConsumerRecords someConsumerRecords( - final ConsumerRecord... consumerRecords + private static ConsumerRecords someConsumerRecords( + final ConsumerRecord... consumerRecords ) { return new ConsumerRecords<>( ImmutableMap.of(TOPIC_PARTITION, ImmutableList.copyOf(consumerRecords))); diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/CommandRunnerStatusMetricTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/CommandRunnerStatusMetricTest.java index 0e56d3da3362..e4ed583c6499 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/CommandRunnerStatusMetricTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/CommandRunnerStatusMetricTest.java @@ -92,6 +92,15 @@ public void shouldUpdateToErrorState() { assertThat(currentGaugeValue(), is(CommandRunner.CommandRunnerStatus.ERROR.name())); } + @Test + public void shouldUpdateToDegradedState() { + // When: + when(commandRunner.checkCommandRunnerStatus()).thenReturn(CommandRunner.CommandRunnerStatus.DEGRADED); + + // Then: + assertThat(currentGaugeValue(), is(CommandRunner.CommandRunnerStatus.DEGRADED.name())); + } + @Test public void shouldRemoveMetricOnClose() { // When: diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/CommandRunnerTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/CommandRunnerTest.java index 7799c2eff2da..f77e68b69e8a 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/CommandRunnerTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/CommandRunnerTest.java @@ -25,6 +25,7 @@ import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.never; @@ -43,11 +44,20 @@ import java.time.Duration; import java.time.Instant; import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.stream.Collectors; + +import io.confluent.ksql.util.Pair; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.serialization.Deserializer; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -88,6 +98,10 @@ public class CommandRunnerTest { private ExecutorService executor; @Mock private Function, List> compactor; + @Mock + private Consumer incompatibleCommandChecker; + @Mock + private Deserializer commandDeserializer; @Captor private ArgumentCaptor threadTaskCaptor; private CommandRunner commandRunner; @@ -101,9 +115,12 @@ public void setup() { when(clusterTerminate.getStatement()) .thenReturn(TerminateCluster.TERMINATE_CLUSTER_STATEMENT_TEXT); - when(queuedCommand1.getCommand()).thenReturn(command); - when(queuedCommand2.getCommand()).thenReturn(command); - when(queuedCommand3.getCommand()).thenReturn(command); + when(queuedCommand1.getAndDeserializeCommand(commandDeserializer)).thenReturn(command); + when(queuedCommand2.getAndDeserializeCommand(commandDeserializer)).thenReturn(command); + when(queuedCommand3.getAndDeserializeCommand(commandDeserializer)).thenReturn(command); + doNothing().when(incompatibleCommandChecker).accept(queuedCommand1); + doNothing().when(incompatibleCommandChecker).accept(queuedCommand2); + doNothing().when(incompatibleCommandChecker).accept(queuedCommand3); when(compactor.apply(any())).thenAnswer(inv -> inv.getArgument(0)); @@ -120,7 +137,9 @@ public void setup() { Duration.ofMillis(COMMAND_RUNNER_HEALTH_TIMEOUT), "", clock, - compactor + compactor, + incompatibleCommandChecker, + commandDeserializer ); } @@ -142,8 +161,8 @@ public void shouldRunThePriorCommandsCorrectly() { @Test public void shouldRunThePriorCommandsWithTerminateCorrectly() { // Given: - givenQueuedCommands(queuedCommand1); - when(queuedCommand1.getCommand()).thenReturn(clusterTerminate); + givenQueuedCommands(queuedCommand1, queuedCommand2, queuedCommand3); + when(queuedCommand1.getAndDeserializeCommand(commandDeserializer)).thenReturn(clusterTerminate); // When: commandRunner.processPriorCommands(); @@ -161,7 +180,7 @@ public void shouldRunThePriorCommandsWithTerminateCorrectly() { public void shouldEarlyOutIfRestoreContainsTerminate() { // Given: givenQueuedCommands(queuedCommand1, queuedCommand2, queuedCommand3); - when(queuedCommand2.getCommand()).thenReturn(clusterTerminate); + when(queuedCommand2.getAndDeserializeCommand(commandDeserializer)).thenReturn(clusterTerminate); // When: commandRunner.processPriorCommands(); @@ -198,6 +217,41 @@ public void shouldOnlyRestoreCompacted() { verify(statementExecutor, never()).handleRestore(queuedCommand2); } + @Test + public void shouldProcessPartialListOfCommandsOnDeserializationExceptionInRestore() { + // Given: + givenQueuedCommands(queuedCommand1, queuedCommand2, queuedCommand3); + doThrow(new SerializationException()).when(incompatibleCommandChecker).accept(queuedCommand3); + + // When: + commandRunner.processPriorCommands(); + + // Then: + final InOrder inOrder = inOrder(statementExecutor); + inOrder.verify(statementExecutor).handleRestore(eq(queuedCommand1)); + inOrder.verify(statementExecutor).handleRestore(eq(queuedCommand2)); + + assertThat(commandRunner.checkCommandRunnerStatus(), is(CommandRunner.CommandRunnerStatus.DEGRADED)); + verify(statementExecutor, never()).handleRestore(queuedCommand3); + + } + + @Test + public void shouldProcessPartialListOfCommandsOnDeserializationExceptionInFetch() { + // Given: + givenQueuedCommands(queuedCommand1, queuedCommand2, queuedCommand3); + doThrow(new SerializationException()).when(incompatibleCommandChecker).accept(queuedCommand2); + + // When: + commandRunner.processPriorCommands(); + + // Then: + verify(statementExecutor).handleRestore(eq(queuedCommand1)); + verify(statementExecutor, never()).handleRestore(queuedCommand2); + verify(statementExecutor, never()).handleRestore(queuedCommand3); + assertThat(commandRunner.checkCommandRunnerStatus(), is(CommandRunner.CommandRunnerStatus.DEGRADED)); + } + @Test public void shouldPullAndRunStatements() { // Given: @@ -248,7 +302,7 @@ public void shouldThrowExceptionIfOverMaxRetries() { public void shouldEarlyOutIfNewCommandsContainsTerminate() { // Given: givenQueuedCommands(queuedCommand1, queuedCommand2, queuedCommand3); - when(queuedCommand2.getCommand()).thenReturn(clusterTerminate); + when(queuedCommand2.getAndDeserializeCommand(commandDeserializer)).thenReturn(clusterTerminate); // When: commandRunner.fetchAndRunCommands(); @@ -346,6 +400,49 @@ public void shouldSubmitTaskOnStart() { inOrder.verify(executor).shutdown(); } + @Test + public void shouldNotStartCommandRunnerThreadIfSerializationExceptionInRestore() throws Exception { + // Given: + givenQueuedCommands(queuedCommand1, queuedCommand2, queuedCommand3); + doThrow(new SerializationException()).when(incompatibleCommandChecker).accept(queuedCommand3); + + // When: + commandRunner.processPriorCommands(); + commandRunner.start(); + + final Runnable threadTask = getThreadTask(); + threadTask.run(); + + // Then: + final InOrder inOrder = inOrder(executor, commandStore); + inOrder.verify(commandStore).wakeup(); + inOrder.verify(executor).awaitTermination(anyLong(), any()); + inOrder.verify(commandStore).close(); + verify(commandStore, never()).getNewCommands(any()); + verify(statementExecutor, times(2)).handleRestore(any()); + } + + @Test + public void shouldCloseEarlyWhenSerializationExceptionInFetch() throws Exception { + // Given: + when(commandStore.getNewCommands(any())) + .thenReturn(Collections.singletonList(queuedCommand1)) + .thenReturn(Collections.singletonList(queuedCommand2)); + doThrow(new SerializationException()).when(incompatibleCommandChecker).accept(queuedCommand2); + + // When: + commandRunner.start(); + verify(commandStore, never()).close(); + final Runnable threadTask = getThreadTask(); + threadTask.run(); + + // Then: + final InOrder inOrder = inOrder(executor, commandStore); + inOrder.verify(commandStore).wakeup(); + inOrder.verify(executor).awaitTermination(anyLong(), any()); + inOrder.verify(commandStore).close(); + } + @Test public void shouldCloseTheCommandRunnerCorrectly() throws Exception { // Given: diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/CommandStoreTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/CommandStoreTest.java index 1b8ca47bf770..2f534db14572 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/CommandStoreTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/CommandStoreTest.java @@ -28,6 +28,7 @@ import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -40,10 +41,13 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.Queue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; + +import io.confluent.ksql.util.Pair; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.clients.producer.Producer; @@ -51,6 +55,8 @@ import org.apache.kafka.clients.producer.RecordMetadata; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serializer; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -78,6 +84,12 @@ public class CommandStoreTest { private CommandTopic commandTopic; @Mock private Producer transactionalProducer; + @Mock + private Serializer commandIdSerializer; + @Mock + private Serializer commandSerializer; + @Mock + private Deserializer commandIdDeserializer; private final CommandId commandId = new CommandId(CommandId.Type.STREAM, "foo", CommandId.Action.CREATE); @@ -135,7 +147,10 @@ public void setUp() { sequenceNumberFutureStore, Collections.emptyMap(), Collections.emptyMap(), - TIMEOUT + TIMEOUT, + commandIdSerializer, + commandSerializer, + commandIdDeserializer ); } @@ -176,6 +191,7 @@ public void shouldCleanupCommandStatusOnProduceError() { @Test public void shouldEnqueueNewAfterHandlingExistingCommand() { // Given: + when(commandIdDeserializer.deserialize(any(), any())).thenReturn(commandId); commandStore.enqueueCommand(commandId, command, transactionalProducer); commandStore.getNewCommands(NEW_CMDS_TIMEOUT); @@ -188,16 +204,19 @@ public void shouldRegisterBeforeDistributeAndReturnStatusOnGetNewCommands() { // Given: when(transactionalProducer.send(any(ProducerRecord.class))).thenAnswer( invocation -> { - final QueuedCommand queuedCommand = commandStore.getNewCommands(NEW_CMDS_TIMEOUT).get(0); - assertThat(queuedCommand.getCommandId(), equalTo(commandId)); - assertThat(queuedCommand.getStatus().isPresent(), equalTo(true)); + final QueuedCommand command = commandStore.getNewCommands(NEW_CMDS_TIMEOUT).get(0); assertThat( - queuedCommand.getStatus().get().getStatus().getStatus(), + command.getAndDeserializeCommandId(), + equalTo(commandId)); + assertThat(command.getStatus().isPresent(), equalTo(true)); + assertThat( + command.getStatus().get().getStatus().getStatus(), equalTo(CommandStatus.Status.QUEUED)); - assertThat(queuedCommand.getOffset(), equalTo(0L)); + assertThat(command.getOffset(), equalTo(0L)); return testFuture; } ); + when(commandIdDeserializer.deserialize(any(),any())).thenReturn(commandId); // When: commandStore.enqueueCommand(commandId, command, transactionalProducer); @@ -209,21 +228,27 @@ public void shouldRegisterBeforeDistributeAndReturnStatusOnGetNewCommands() { @Test public void shouldFilterNullCommands() { // Given: - final ConsumerRecords records = buildRecords( + final ConsumerRecords records = buildRecords( commandId, null, commandId, command); + final Deserializer commandDeserializer = mock(Deserializer.class); + when(commandDeserializer.deserialize(any(), any())).thenReturn(command); when(commandTopic.getNewCommands(any())).thenReturn(records); - // When: - final List commands = commandStore.getNewCommands(NEW_CMDS_TIMEOUT); + // When: + final List commands = + commandStore.getNewCommands(NEW_CMDS_TIMEOUT); // Then: assertThat(commands, hasSize(1)); - assertThat(commands.get(0).getCommandId(), equalTo(commandId)); - assertThat(commands.get(0).getCommand(), equalTo(command)); + assertThat( + commands.get(0).getAndDeserializeCommandId(), + equalTo(commandId)); + assertThat( + commands.get(0).getAndDeserializeCommand(commandDeserializer), + equalTo(command)); } - @Test public void shouldDistributeCommand() { when(transactionalProducer.send(any(ProducerRecord.class))).thenReturn(testFuture); @@ -342,14 +367,20 @@ public void shouldStartCommandTopicOnStart() { verify(commandTopic).start(); } - private static ConsumerRecords buildRecords(final Object... args) { + private static ConsumerRecords buildRecords(final Object... args) { assertThat(args.length % 2, equalTo(0)); - final List> records = new ArrayList<>(); + final List> records = new ArrayList<>(); for (int i = 0; i < args.length; i += 2) { assertThat(args[i], instanceOf(CommandId.class)); assertThat(args[i + 1], anyOf(is(nullValue()), instanceOf(Command.class))); + records.add( - new ConsumerRecord<>(COMMAND_TOPIC_NAME, 0, 0, (CommandId) args[i], (Command) args[i + 1])); + new ConsumerRecord<>( + COMMAND_TOPIC_NAME, + 0, + 0, + InternalTopicSerdes.serializer().serialize(null, args[i]), + args[i + 1] == null ? null : InternalTopicSerdes.serializer().serialize(null, args[i + 1]))); } return new ConsumerRecords<>(Collections.singletonMap(COMMAND_TOPIC_PARTITION, records)); } diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/InteractiveStatementExecutorTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/InteractiveStatementExecutorTest.java index 3de675f85e50..64acfbc3c99a 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/InteractiveStatementExecutorTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/InteractiveStatementExecutorTest.java @@ -75,6 +75,7 @@ import java.util.Optional; import java.util.concurrent.TimeUnit; import kafka.zookeeper.ZooKeeperClientException; +import org.apache.kafka.common.serialization.Deserializer; import org.apache.kafka.streams.StreamsConfig; import org.hamcrest.CoreMatchers; import org.junit.After; @@ -121,6 +122,8 @@ public class InteractiveStatementExecutorTest { private KsqlPlan plan; @Mock private CommandStatusFuture status; + @Mock + private Deserializer commandDeserializer; private Command plannedCommand; @@ -150,13 +153,15 @@ public void setUp() { serviceContext, ksqlEngine, statementParser, - hybridQueryIdGenerator + hybridQueryIdGenerator, + InternalTopicSerdes.deserializer(Command.class) ); statementExecutorWithMocks = new InteractiveStatementExecutor( serviceContext, mockEngine, mockParser, - mockQueryIdGenerator + mockQueryIdGenerator, + commandDeserializer ); statementExecutor.configure(ksqlConfig); @@ -199,7 +204,8 @@ public void shouldThrowOnHandleStatementIfNotConfigured() { serviceContext, mockEngine, mockParser, - mockQueryIdGenerator + mockQueryIdGenerator, + commandDeserializer ); // When: @@ -213,7 +219,8 @@ public void shouldThrowOnHandleRestoreIfNotConfigured() { serviceContext, mockEngine, mockParser, - mockQueryIdGenerator + mockQueryIdGenerator, + commandDeserializer ); // When: @@ -586,11 +593,14 @@ public void shouldSetNextQueryIdToNextOffsetWhenExecutingRestoreCommand() { // Given: mockReplayCSAS(new QueryId("csas-query-id")); + final Command command = new Command("CSAS", emptyMap(), emptyMap(), Optional.empty()); + when(commandDeserializer.deserialize(any(), any())).thenReturn(command); + // When: statementExecutorWithMocks.handleRestore( new QueuedCommand( new CommandId(Type.STREAM, "foo", Action.CREATE), - new Command("CSAS", emptyMap(), emptyMap(), Optional.empty()), + command, Optional.empty(), 2L ) @@ -606,12 +616,14 @@ public void shouldSkipStartWhenReplayingLog() { final QueryId queryId = new QueryId("csas-query-id"); final String name = "foo"; final PersistentQueryMetadata mockQuery = mockReplayCSAS(queryId); + final Command command = new Command("CSAS", emptyMap(), emptyMap(), Optional.empty()); + when(commandDeserializer.deserialize(any(), any())).thenReturn(command); // When: statementExecutorWithMocks.handleRestore( new QueuedCommand( new CommandId(Type.STREAM, name, Action.CREATE), - new Command("CSAS", emptyMap(), emptyMap(), Optional.empty()), + command, Optional.empty(), 0L ) @@ -672,11 +684,14 @@ public void shouldNotCascadeDropStreamCommand() { when(mockEngine.execute(eq(serviceContext), eqConfiguredPlan(plan))) .thenReturn(ExecuteResult.of("SUCCESS")); + final Command command = new Command(drop, emptyMap(), emptyMap(), Optional.empty()); + when(commandDeserializer.deserialize(any(), any())).thenReturn(command); + // When: statementExecutorWithMocks.handleRestore( new QueuedCommand( new CommandId(Type.STREAM, "foo", Action.DROP), - new Command(drop, emptyMap(), emptyMap(), Optional.empty()), + command, Optional.empty(), 0L ) @@ -699,11 +714,14 @@ public void shouldTerminateAll() { when(mockEngine.getPersistentQueries()).thenReturn(ImmutableList.of(query0, query1)); + final Command command = new Command("terminate all", emptyMap(), emptyMap(), Optional.empty()); + when(commandDeserializer.deserialize(any(), any())).thenReturn(command); + // When: statementExecutorWithMocks.handleStatement( new QueuedCommand( new CommandId(Type.TERMINATE, "-", Action.EXECUTE), - new Command("terminate all", emptyMap(), emptyMap(), Optional.empty()), + command, Optional.empty(), 0L ) @@ -731,9 +749,12 @@ public void shouldDoIdempotentTerminate() { .thenReturn(Optional.of(query)) .thenReturn(Optional.empty()); + final Command command = new Command("terminate all", emptyMap(), emptyMap(), Optional.empty()); + when(commandDeserializer.deserialize(any(), any())).thenReturn(command); + final QueuedCommand cmd = new QueuedCommand( new CommandId(Type.TERMINATE, "-", Action.EXECUTE), - new Command("terminate all", emptyMap(), emptyMap(), Optional.empty()), + command, Optional.empty(), 0L ); @@ -922,13 +943,17 @@ private void handleStatement( handleStatement(statementExecutor, command, commandId, commandStatus, offset); } - private static void handleStatement( + private void handleStatement( final InteractiveStatementExecutor statementExecutor, final Command command, final CommandId commandId, final Optional commandStatus, final long offset) { - statementExecutor.handleStatement(new QueuedCommand(commandId, command, commandStatus, offset)); + when(queuedCommand.getAndDeserializeCommand(any())).thenReturn(command); + when(queuedCommand.getAndDeserializeCommandId()).thenReturn(commandId); + when(queuedCommand.getStatus()).thenReturn(commandStatus); + when(queuedCommand.getOffset()).thenReturn(offset); + statementExecutor.handleStatement(queuedCommand); } private void terminateQueries() { diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/RecoveryTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/RecoveryTest.java index ed3f93a89fab..77152df78d91 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/RecoveryTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/computation/RecoveryTest.java @@ -221,7 +221,8 @@ private class KsqlServer { serverState, "ksql-service-id", Duration.ofMillis(2000), - "" + "", + InternalTopicSerdes.deserializer(Command.class) ); this.ksqlResource = new KsqlResource(