Skip to content

Commit

Permalink
KAFKA-2763: better stream task assignment
Browse files Browse the repository at this point in the history
guozhangwang

When a rebalance happens, each consumer reports the following information to the coordinator:
* Client UUID (a unique id assigned to an instance of KafkaStreaming)
* Task ids of previously running tasks
* Task ids of valid local states in the client's state directory

TaskAssignor does the following
* Assign each task to the client that was running it previously. If there is no such client, assign the task to a client that has valid local state for it.
* Try to balance the load among stream threads.
  * A client may have more than one stream thread. The assignor tries to assign tasks to each client in proportion to its number of threads.

Author: Yasuhiro Matsuda <yasuhiro@confluent.io>

Reviewers: Guozhang Wang

Closes #497 from ymatsuda/task_assignment
  • Loading branch information
Yasuhiro Matsuda authored and Geoff Anderson committed Nov 18, 2015
1 parent 921fc42 commit 0d3def0
Show file tree
Hide file tree
Showing 16 changed files with 1,464 additions and 70 deletions.
Expand Up @@ -29,6 +29,7 @@
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

Expand Down Expand Up @@ -85,11 +86,13 @@ public class KafkaStreaming {
private final StreamThread[] threads;

private String clientId;
private final UUID uuid;
private final Metrics metrics;

public KafkaStreaming(TopologyBuilder builder, StreamingConfig config) throws Exception {
// create the metrics
this.time = new SystemTime();
this.uuid = UUID.randomUUID();

MetricConfig metricConfig = new MetricConfig().samples(config.getInt(StreamingConfig.METRICS_NUM_SAMPLES_CONFIG))
.timeWindow(config.getLong(StreamingConfig.METRICS_SAMPLE_WINDOW_MS_CONFIG),
Expand All @@ -104,7 +107,7 @@ public KafkaStreaming(TopologyBuilder builder, StreamingConfig config) throws Ex

this.threads = new StreamThread[config.getInt(StreamingConfig.NUM_STREAM_THREADS_CONFIG)];
for (int i = 0; i < this.threads.length; i++) {
this.threads[i] = new StreamThread(builder, config, this.clientId, this.metrics, this.time);
this.threads[i] = new StreamThread(builder, config, this.clientId, this.uuid, this.metrics, this.time);
}
}

Expand Down
Expand Up @@ -27,8 +27,8 @@
import org.apache.kafka.common.serialization.Deserializer;
import org.apache.kafka.common.serialization.Serializer;
import org.apache.kafka.streams.processor.DefaultPartitionGrouper;
import org.apache.kafka.streams.processor.PartitionGrouper;
import org.apache.kafka.streams.processor.internals.KafkaStreamingPartitionAssignor;
import org.apache.kafka.streams.processor.internals.StreamThread;

import java.util.Map;

Expand Down Expand Up @@ -205,16 +205,16 @@ public class StreamingConfig extends AbstractConfig {
}

public static class InternalConfig {
public static final String PARTITION_GROUPER_INSTANCE = "__partition.grouper.instance__";
public static final String STREAM_THREAD_INSTANCE = "__stream.thread.instance__";
}

/**
 * Creates a streaming configuration from the supplied properties.
 *
 * The properties are passed unchanged to the superclass constructor together with
 * {@code CONFIG}; no additional processing happens here.
 * NOTE(review): {@code CONFIG} is declared outside this view -- presumably the
 * config definition for this class; confirm.
 *
 * @param props the configuration properties to pass to the superclass
 */
public StreamingConfig(Map<?, ?> props) {
super(CONFIG, props);
}

public Map<String, Object> getConsumerConfigs(PartitionGrouper partitionGrouper) {
public Map<String, Object> getConsumerConfigs(StreamThread streamThread) {
Map<String, Object> props = getConsumerConfigs();
props.put(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE, partitionGrouper);
props.put(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE, streamThread);
props.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, KafkaStreamingPartitionAssignor.class.getName());
return props;
}
Expand Down
Expand Up @@ -50,4 +50,8 @@ public Set<TaskId> taskIds(TopicPartition partition) {
return partitionAssignor.taskIds(partition);
}

/**
 * Returns the standby tasks by delegating to the partition assignor.
 *
 * @return the value of {@code partitionAssignor.standbyTasks()}
 */
public Set<TaskId> standbyTasks() {
    final Set<TaskId> standby = partitionAssignor.standbyTasks();
    return standby;
}

}
Expand Up @@ -17,7 +17,9 @@

package org.apache.kafka.streams.processor;

public class TaskId {
import java.nio.ByteBuffer;

public class TaskId implements Comparable<TaskId> {

public final int topicGroupId;
public final int partition;
Expand Down Expand Up @@ -45,6 +47,15 @@ public static TaskId parse(String string) {
}
}

/**
 * Serializes this task id into the given buffer at its current position.
 *
 * Two ints are written with relative puts: {@code topicGroupId} first,
 * then {@code partition}.
 *
 * @param buf the buffer to write into
 */
public void writeTo(ByteBuffer buf) {
    // ByteBuffer.putInt returns the buffer itself, so the two writes can be chained.
    buf.putInt(topicGroupId).putInt(partition);
}

/**
 * Reads a task id from the given buffer at its current position.
 *
 * Two ints are consumed with relative gets and passed to the constructor
 * in the order they were read.
 *
 * @param buf the buffer to read from
 * @return a new {@code TaskId} built from the two ints read
 */
public static TaskId readFrom(ByteBuffer buf) {
    final int first = buf.getInt();
    final int second = buf.getInt();
    return new TaskId(first, second);
}

@Override
public boolean equals(Object o) {
if (o instanceof TaskId) {
Expand All @@ -61,6 +72,16 @@ public int hashCode() {
return (int) (n % 0xFFFFFFFFL);
}

/**
 * Orders task ids by {@code topicGroupId} first and by {@code partition} second.
 *
 * @param other the task id to compare against
 * @return -1 if this id sorts before {@code other}, 1 if it sorts after, 0 if both fields are equal
 */
@Override
public int compareTo(TaskId other) {
    // The topic group is the primary sort key.
    if (this.topicGroupId != other.topicGroupId) {
        return this.topicGroupId < other.topicGroupId ? -1 : 1;
    }
    // Same topic group: fall back to the partition number.
    if (this.partition != other.partition) {
        return this.partition < other.partition ? -1 : 1;
    }
    return 0;
}

public static class TaskIdFormatException extends RuntimeException {
}
}
Expand Up @@ -23,37 +23,49 @@
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.streams.StreamingConfig;
import org.apache.kafka.streams.processor.PartitionGrouper;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
import org.apache.kafka.streams.processor.internals.assignment.ClientState;
import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo;
import org.apache.kafka.streams.processor.internals.assignment.TaskAssignmentException;
import org.apache.kafka.streams.processor.internals.assignment.TaskAssignor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;

public class KafkaStreamingPartitionAssignor implements PartitionAssignor, Configurable {

private static final Logger log = LoggerFactory.getLogger(KafkaStreamingPartitionAssignor.class);

private PartitionGrouper partitionGrouper;
private StreamThread streamThread;
private Map<TopicPartition, Set<TaskId>> partitionToTaskIds;
private Set<TaskId> standbyTasks;

@Override
public void configure(Map<String, ?> configs) {
Object o = configs.get(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE);
if (o == null)
throw new KafkaException("PartitionGrouper is not specified");
Object o = configs.get(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE);
if (o == null) {
KafkaException ex = new KafkaException("StreamThread is not specified");
log.error(ex.getMessage(), ex);
throw ex;
}

if (!PartitionGrouper.class.isInstance(o))
throw new KafkaException(o.getClass().getName() + " is not an instance of " + PartitionGrouper.class.getName());
if (!(o instanceof StreamThread)) {
KafkaException ex = new KafkaException(o.getClass().getName() + " is not an instance of " + StreamThread.class.getName());
log.error(ex.getMessage(), ex);
throw ex;
}

partitionGrouper = (PartitionGrouper) o;
partitionGrouper.partitionAssignor(this);
streamThread = (StreamThread) o;
streamThread.partitionGrouper.partitionAssignor(this);
}

@Override
Expand All @@ -63,38 +75,110 @@ public String name() {

@Override
public Subscription subscription(Set<String> topics) {
return new Subscription(new ArrayList<>(topics));
// Adds the following information to subscription
// 1. Client UUID (a unique id assigned to an instance of KafkaStreaming)
// 2. Task ids of previously running tasks
// 3. Task ids of valid local states on the client's state directory.

Set<TaskId> prevTasks = streamThread.prevTasks();
Set<TaskId> standbyTasks = streamThread.cachedTasks();
standbyTasks.removeAll(prevTasks);
SubscriptionInfo data = new SubscriptionInfo(streamThread.clientUUID, prevTasks, standbyTasks);

return new Subscription(new ArrayList<>(topics), data.encode());
}

@Override
public Map<String, Assignment> assign(Cluster metadata, Map<String, Subscription> subscriptions) {
Map<TaskId, Set<TopicPartition>> partitionGroups = partitionGrouper.partitionGroups(metadata);
// This assigns tasks to consumer clients in two steps.
// 1. using TaskAssignor tasks are assigned to streaming clients.
// - Assign a task to a client which was running it previously.
// If there is no such client, assign a task to a client which has its valid local state.
// - A client may have more than one stream thread.
// The assignor tries to assign tasks to each client in proportion to its number of threads.
// - We try not to assign the same set of tasks to two different clients
// We do the assignment in one pass, so the result may not satisfy all of the above.
// 2. within each client, tasks are assigned to consumer clients in round-robin manner.

Map<UUID, Set<String>> consumersByClient = new HashMap<>();
Map<UUID, ClientState<TaskId>> states = new HashMap<>();

// Decode subscription info
for (Map.Entry<String, Subscription> entry : subscriptions.entrySet()) {
String consumerId = entry.getKey();
Subscription subscription = entry.getValue();

SubscriptionInfo info = SubscriptionInfo.decode(subscription.userData());

Set<String> consumers = consumersByClient.get(info.clientUUID);
if (consumers == null) {
consumers = new HashSet<>();
consumersByClient.put(info.clientUUID, consumers);
}
consumers.add(consumerId);

ClientState<TaskId> state = states.get(info.clientUUID);
if (state == null) {
state = new ClientState<>();
states.put(info.clientUUID, state);
}

state.prevActiveTasks.addAll(info.prevTasks);
state.prevAssignedTasks.addAll(info.prevTasks);
state.prevAssignedTasks.addAll(info.standbyTasks);
state.capacity = state.capacity + 1d;
}

String[] clientIds = subscriptions.keySet().toArray(new String[subscriptions.size()]);
TaskId[] taskIds = partitionGroups.keySet().toArray(new TaskId[partitionGroups.size()]);
// Get partition groups from the partition grouper
Map<TaskId, Set<TopicPartition>> partitionGroups = streamThread.partitionGrouper.partitionGroups(metadata);

states = TaskAssignor.assign(states, partitionGroups.keySet(), 0); // TODO: enable standby tasks
Map<String, Assignment> assignment = new HashMap<>();

for (int i = 0; i < clientIds.length; i++) {
List<TopicPartition> partitions = new ArrayList<>();
List<TaskId> ids = new ArrayList<>();
for (int j = i; j < taskIds.length; j += clientIds.length) {
TaskId taskId = taskIds[j];
for (TopicPartition partition : partitionGroups.get(taskId)) {
partitions.add(partition);
ids.add(taskId);
}
for (Map.Entry<UUID, Set<String>> entry : consumersByClient.entrySet()) {
UUID uuid = entry.getKey();
Set<String> consumers = entry.getValue();
ClientState<TaskId> state = states.get(uuid);

ArrayList<TaskId> taskIds = new ArrayList<>(state.assignedTasks.size());
final int numActiveTasks = state.activeTasks.size();
for (TaskId id : state.activeTasks) {
taskIds.add(id);
}
ByteBuffer buf = ByteBuffer.allocate(4 + ids.size() * 8);
//version
buf.putInt(1);
// encode task ids
for (TaskId id : ids) {
buf.putInt(id.topicGroupId);
buf.putInt(id.partition);
for (TaskId id : state.assignedTasks) {
if (!state.activeTasks.contains(id))
taskIds.add(id);
}

final int numConsumers = consumers.size();
List<TaskId> active = new ArrayList<>();
Set<TaskId> standby = new HashSet<>();

int i = 0;
for (String consumer : consumers) {
List<TopicPartition> partitions = new ArrayList<>();

final int numTaskIds = taskIds.size();
for (int j = i; j < numTaskIds; j += numConsumers) {
TaskId taskId = taskIds.get(j);
if (j < numActiveTasks) {
for (TopicPartition partition : partitionGroups.get(taskId)) {
partitions.add(partition);
active.add(taskId);
}
} else {
// no partition to a standby task
standby.add(taskId);
}
}

AssignmentInfo data = new AssignmentInfo(active, standby);
assignment.put(consumer, new Assignment(partitions, data.encode()));
i++;

active.clear();
standby.clear();
}
buf.rewind();
assignment.put(clientIds[i], new Assignment(partitions, buf));
}

return assignment;
Expand All @@ -103,27 +187,29 @@ public Map<String, Assignment> assign(Cluster metadata, Map<String, Subscription
@Override
public void onAssignment(Assignment assignment) {
List<TopicPartition> partitions = assignment.partitions();
ByteBuffer data = assignment.userData();
data.rewind();

AssignmentInfo info = AssignmentInfo.decode(assignment.userData());
this.standbyTasks = info.standbyTasks;

Map<TopicPartition, Set<TaskId>> partitionToTaskIds = new HashMap<>();
Iterator<TaskId> iter = info.activeTasks.iterator();
for (TopicPartition partition : partitions) {
Set<TaskId> taskIds = partitionToTaskIds.get(partition);
if (taskIds == null) {
taskIds = new HashSet<>();
partitionToTaskIds.put(partition, taskIds);
}

// check version
int version = data.getInt();
if (version == 1) {
for (TopicPartition partition : partitions) {
Set<TaskId> taskIds = partitionToTaskIds.get(partition);
if (taskIds == null) {
taskIds = new HashSet<>();
partitionToTaskIds.put(partition, taskIds);
}
// decode a task id
taskIds.add(new TaskId(data.getInt(), data.getInt()));
if (iter.hasNext()) {
taskIds.add(iter.next());
} else {
TaskAssignmentException ex = new TaskAssignmentException(
"failed to find a task id for the partition=" + partition.toString() +
", partitions=" + partitions.size() + ", assignmentInfo=" + info.toString()
);
log.error(ex.getMessage(), ex);
throw ex;
}
} else {
KafkaException ex = new KafkaException("unknown assignment data version: " + version);
log.error(ex.getMessage(), ex);
throw ex;
}
this.partitionToTaskIds = partitionToTaskIds;
}
Expand All @@ -132,4 +218,7 @@ public Set<TaskId> taskIds(TopicPartition partition) {
return partitionToTaskIds.get(partition);
}

/**
 * Returns the standby task set currently stored in this assignor.
 *
 * @return the current value of the {@code standbyTasks} field
 */
public Set<TaskId> standbyTasks() {
    return this.standbyTasks;
}
}

0 comments on commit 0d3def0

Please sign in to comment.